mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-01 21:00:48 +08:00
Merge pull request #620 from goasleep/feature/export_search_engine
feat:expose the search engine params to user
This commit is contained in:
commit
1bcc0bff0a
@ -2,9 +2,10 @@
|
||||
CSVScraperMultiGraph Module
|
||||
"""
|
||||
|
||||
from copy import copy, deepcopy
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
from .base_graph import BaseGraph
|
||||
from .abstract_graph import AbstractGraph
|
||||
from .csv_scraper_graph import CSVScraperGraph
|
||||
@ -12,6 +13,7 @@ from ..nodes import (
|
||||
GraphIteratorNode,
|
||||
MergeAnswersNode
|
||||
)
|
||||
from ..utils.copy import safe_deepcopy
|
||||
|
||||
class CSVScraperMultiGraph(AbstractGraph):
|
||||
"""
|
||||
@ -46,10 +48,7 @@ class CSVScraperMultiGraph(AbstractGraph):
|
||||
|
||||
self.max_results = config.get("max_results", 3)
|
||||
|
||||
if all(isinstance(value, str) for value in config.values()):
|
||||
self.copy_config = copy(config)
|
||||
else:
|
||||
self.copy_config = deepcopy(config)
|
||||
self.copy_config = safe_deepcopy(config)
|
||||
|
||||
super().__init__(prompt, config, source, schema)
|
||||
|
||||
|
||||
@ -2,9 +2,10 @@
|
||||
JSONScraperMultiGraph Module
|
||||
"""
|
||||
|
||||
from copy import copy, deepcopy
|
||||
from copy import deepcopy
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .base_graph import BaseGraph
|
||||
from .abstract_graph import AbstractGraph
|
||||
from .json_scraper_graph import JSONScraperGraph
|
||||
@ -12,6 +13,7 @@ from ..nodes import (
|
||||
GraphIteratorNode,
|
||||
MergeAnswersNode
|
||||
)
|
||||
from ..utils.copy import safe_deepcopy
|
||||
|
||||
class JSONScraperMultiGraph(AbstractGraph):
|
||||
"""
|
||||
@ -45,10 +47,7 @@ class JSONScraperMultiGraph(AbstractGraph):
|
||||
|
||||
self.max_results = config.get("max_results", 3)
|
||||
|
||||
if all(isinstance(value, str) for value in config.values()):
|
||||
self.copy_config = copy(config)
|
||||
else:
|
||||
self.copy_config = deepcopy(config)
|
||||
self.copy_config = safe_deepcopy(config)
|
||||
|
||||
self.copy_schema = deepcopy(schema)
|
||||
|
||||
|
||||
@ -12,6 +12,7 @@ from ..nodes import (
|
||||
GraphIteratorNode,
|
||||
MergeAnswersNode
|
||||
)
|
||||
from ..utils.copy import safe_deepcopy
|
||||
|
||||
class MDScraperMultiGraph(AbstractGraph):
|
||||
"""
|
||||
@ -42,11 +43,7 @@ class MDScraperMultiGraph(AbstractGraph):
|
||||
"""
|
||||
|
||||
def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[BaseModel] = None):
|
||||
if all(isinstance(value, str) for value in config.values()):
|
||||
self.copy_config = copy(config)
|
||||
else:
|
||||
self.copy_config = deepcopy(config)
|
||||
|
||||
self.copy_config = safe_deepcopy(config)
|
||||
self.copy_schema = deepcopy(schema)
|
||||
|
||||
super().__init__(prompt, config, source, schema)
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
OmniSearchGraph Module
|
||||
"""
|
||||
|
||||
from copy import copy, deepcopy
|
||||
from copy import deepcopy
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -15,6 +15,7 @@ from ..nodes import (
|
||||
GraphIteratorNode,
|
||||
MergeAnswersNode
|
||||
)
|
||||
from ..utils.copy import safe_deepcopy
|
||||
|
||||
|
||||
class OmniSearchGraph(AbstractGraph):
|
||||
@ -48,10 +49,7 @@ class OmniSearchGraph(AbstractGraph):
|
||||
|
||||
self.max_results = config.get("max_results", 3)
|
||||
|
||||
if all(isinstance(value, str) for value in config.values()):
|
||||
self.copy_config = copy(config)
|
||||
else:
|
||||
self.copy_config = deepcopy(config)
|
||||
self.copy_config = safe_deepcopy(config)
|
||||
|
||||
self.copy_schema = deepcopy(schema)
|
||||
|
||||
@ -85,7 +83,8 @@ class OmniSearchGraph(AbstractGraph):
|
||||
output=["urls"],
|
||||
node_config={
|
||||
"llm_model": self.llm_model,
|
||||
"max_results": self.max_results
|
||||
"max_results": self.max_results,
|
||||
"search_engine": self.copy_config.get("search_engine")
|
||||
}
|
||||
)
|
||||
graph_iterator_node = GraphIteratorNode(
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
PdfScraperMultiGraph Module
|
||||
"""
|
||||
|
||||
from copy import copy, deepcopy
|
||||
from copy import deepcopy
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel
|
||||
from .base_graph import BaseGraph
|
||||
@ -12,6 +12,7 @@ from ..nodes import (
|
||||
GraphIteratorNode,
|
||||
MergeAnswersNode
|
||||
)
|
||||
from ..utils.copy import safe_deepcopy
|
||||
|
||||
class PdfScraperMultiGraph(AbstractGraph):
|
||||
"""
|
||||
@ -44,10 +45,7 @@ class PdfScraperMultiGraph(AbstractGraph):
|
||||
def __init__(self, prompt: str, source: List[str],
|
||||
config: dict, schema: Optional[BaseModel] = None):
|
||||
|
||||
if all(isinstance(value, str) for value in config.values()):
|
||||
self.copy_config = copy(config)
|
||||
else:
|
||||
self.copy_config = deepcopy(config)
|
||||
self.copy_config = safe_deepcopy(config)
|
||||
|
||||
self.copy_schema = deepcopy(schema)
|
||||
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
ScriptCreatorMultiGraph Module
|
||||
"""
|
||||
|
||||
from copy import copy, deepcopy
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
@ -15,6 +14,7 @@ from ..nodes import (
|
||||
GraphIteratorNode,
|
||||
MergeGeneratedScriptsNode
|
||||
)
|
||||
from ..utils.copy import safe_deepcopy
|
||||
|
||||
class ScriptCreatorMultiGraph(AbstractGraph):
|
||||
"""
|
||||
@ -47,10 +47,7 @@ class ScriptCreatorMultiGraph(AbstractGraph):
|
||||
|
||||
self.max_results = config.get("max_results", 3)
|
||||
|
||||
if all(isinstance(value, str) for value in config.values()):
|
||||
self.copy_config = copy(config)
|
||||
else:
|
||||
self.copy_config = deepcopy(config)
|
||||
self.copy_config = safe_deepcopy(config)
|
||||
|
||||
super().__init__(prompt, config, source, schema)
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
SearchGraph Module
|
||||
"""
|
||||
|
||||
from copy import copy, deepcopy
|
||||
from copy import deepcopy
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -15,6 +15,7 @@ from ..nodes import (
|
||||
GraphIteratorNode,
|
||||
MergeAnswersNode
|
||||
)
|
||||
from ..utils.copy import safe_deepcopy
|
||||
|
||||
class SearchGraph(AbstractGraph):
|
||||
"""
|
||||
@ -47,10 +48,7 @@ class SearchGraph(AbstractGraph):
|
||||
def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None):
|
||||
self.max_results = config.get("max_results", 3)
|
||||
|
||||
if all(isinstance(value, str) for value in config.values()):
|
||||
self.copy_config = copy(config)
|
||||
else:
|
||||
self.copy_config = deepcopy(config)
|
||||
self.copy_config = safe_deepcopy(config)
|
||||
self.copy_schema = deepcopy(schema)
|
||||
self.considered_urls = [] # New attribute to store URLs
|
||||
|
||||
@ -78,7 +76,8 @@ class SearchGraph(AbstractGraph):
|
||||
output=["urls"],
|
||||
node_config={
|
||||
"llm_model": self.llm_model,
|
||||
"max_results": self.max_results
|
||||
"max_results": self.max_results,
|
||||
"search_engine": self.copy_config.get("search_engine")
|
||||
}
|
||||
)
|
||||
graph_iterator_node = GraphIteratorNode(
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
SmartScraperMultiGraph Module
|
||||
"""
|
||||
|
||||
from copy import copy, deepcopy
|
||||
from copy import deepcopy
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -14,6 +14,7 @@ from ..nodes import (
|
||||
GraphIteratorNode,
|
||||
MergeAnswersNode
|
||||
)
|
||||
from ..utils.copy import safe_deepcopy
|
||||
|
||||
class SmartScraperMultiGraph(AbstractGraph):
|
||||
"""
|
||||
@ -48,10 +49,7 @@ class SmartScraperMultiGraph(AbstractGraph):
|
||||
|
||||
self.max_results = config.get("max_results", 3)
|
||||
|
||||
if all(isinstance(value, str) for value in config.values()):
|
||||
self.copy_config = copy(config)
|
||||
else:
|
||||
self.copy_config = deepcopy(config)
|
||||
self.copy_config = safe_deepcopy(config)
|
||||
|
||||
self.copy_schema = deepcopy(schema)
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
XMLScraperMultiGraph Module
|
||||
"""
|
||||
|
||||
from copy import copy, deepcopy
|
||||
from copy import deepcopy
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -14,6 +14,7 @@ from ..nodes import (
|
||||
GraphIteratorNode,
|
||||
MergeAnswersNode
|
||||
)
|
||||
from ..utils.copy import safe_deepcopy
|
||||
|
||||
class XMLScraperMultiGraph(AbstractGraph):
|
||||
"""
|
||||
@ -46,10 +47,7 @@ class XMLScraperMultiGraph(AbstractGraph):
|
||||
def __init__(self, prompt: str, source: List[str],
|
||||
config: dict, schema: Optional[BaseModel] = None):
|
||||
|
||||
if all(isinstance(value, str) for value in config.values()):
|
||||
self.copy_config = copy(config)
|
||||
else:
|
||||
self.copy_config = deepcopy(config)
|
||||
self.copy_config = safe_deepcopy(config)
|
||||
|
||||
self.copy_schema = deepcopy(schema)
|
||||
|
||||
|
||||
@ -41,7 +41,11 @@ class SearchInternetNode(BaseNode):
|
||||
self.verbose = (
|
||||
False if node_config is None else node_config.get("verbose", False)
|
||||
)
|
||||
self.search_engine = node_config.get("search_engine", "google")
|
||||
self.search_engine = (
|
||||
node_config["search_engine"]
|
||||
if node_config.get("search_engine")
|
||||
else "google"
|
||||
)
|
||||
self.max_results = node_config.get("max_results", 3)
|
||||
|
||||
def execute(self, state: dict) -> dict:
|
||||
|
||||
75
scrapegraphai/utils/copy.py
Normal file
75
scrapegraphai/utils/copy.py
Normal file
@ -0,0 +1,75 @@
|
||||
import copy
|
||||
from typing import Any, Dict, Optional
|
||||
from pydantic.v1 import BaseModel
|
||||
|
||||
class DeepCopyError(Exception):
|
||||
"""Custom exception raised when an object cannot be deep-copied."""
|
||||
pass
|
||||
|
||||
def safe_deepcopy(obj: Any) -> Any:
|
||||
"""
|
||||
Attempts to create a deep copy of the object using `copy.deepcopy`
|
||||
whenever possible. If that fails, it falls back to custom deep copy
|
||||
logic. If that also fails, it raises a `DeepCopyError`.
|
||||
|
||||
Args:
|
||||
obj (Any): The object to be copied, which can be of any type.
|
||||
|
||||
Returns:
|
||||
Any: A deep copy of the object if possible; otherwise, a shallow
|
||||
copy if deep copying fails; if neither is possible, the original
|
||||
object is returned.
|
||||
Raises:
|
||||
DeepCopyError: If the object cannot be deep-copied or shallow-copied.
|
||||
"""
|
||||
|
||||
try:
|
||||
|
||||
# Try to use copy.deepcopy first
|
||||
return copy.deepcopy(obj)
|
||||
except (TypeError, AttributeError) as e:
|
||||
# If deepcopy fails, handle specific types manually
|
||||
|
||||
# Handle dictionaries
|
||||
if isinstance(obj, dict):
|
||||
new_obj = {}
|
||||
|
||||
for k, v in obj.items():
|
||||
new_obj[k] = safe_deepcopy(v)
|
||||
return new_obj
|
||||
|
||||
# Handle lists
|
||||
elif isinstance(obj, list):
|
||||
new_obj = []
|
||||
|
||||
for v in obj:
|
||||
new_obj.append(safe_deepcopy(v))
|
||||
return new_obj
|
||||
|
||||
# Handle tuples (immutable, but might contain mutable objects)
|
||||
elif isinstance(obj, tuple):
|
||||
new_obj = tuple(safe_deepcopy(v) for v in obj)
|
||||
|
||||
return new_obj
|
||||
|
||||
# Handle frozensets (immutable, but might contain mutable objects)
|
||||
elif isinstance(obj, frozenset):
|
||||
new_obj = frozenset(safe_deepcopy(v) for v in obj)
|
||||
return new_obj
|
||||
|
||||
# Handle objects with attributes
|
||||
elif hasattr(obj, "__dict__"):
|
||||
# If an object cannot be deep copied, then the sub-properties of \
|
||||
# the object will not be analyzed and shallow copy will be used directly.
|
||||
try:
|
||||
return copy.copy(obj)
|
||||
except (TypeError, AttributeError):
|
||||
raise DeepCopyError(f"Cannot deep copy the object of type {type(obj)}") from e
|
||||
|
||||
|
||||
# Attempt shallow copy as a fallback
|
||||
try:
|
||||
return copy.copy(obj)
|
||||
except (TypeError, AttributeError):
|
||||
raise DeepCopyError(f"Cannot deep copy the object of type {type(obj)}") from e
|
||||
|
||||
186
tests/utils/copy_utils_test.py
Normal file
186
tests/utils/copy_utils_test.py
Normal file
@ -0,0 +1,186 @@
|
||||
import copy
|
||||
import pytest
|
||||
|
||||
# Assuming the custom_deepcopy function is imported or defined above this line
|
||||
from scrapegraphai.utils.copy import DeepCopyError, safe_deepcopy
|
||||
from pydantic.v1 import BaseModel
|
||||
|
||||
class PydantObject(BaseModel):
|
||||
value: int
|
||||
|
||||
class NormalObject:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
self.nested = [1, 2, 3]
|
||||
|
||||
|
||||
class NonDeepcopyable:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
raise TypeError("Forcing shallow copy fallback")
|
||||
|
||||
|
||||
class WithoutDict:
|
||||
__slots__ = ["value"]
|
||||
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
raise TypeError("Forcing shallow copy fallback")
|
||||
|
||||
def __copy__(self):
|
||||
return self
|
||||
|
||||
|
||||
class NonCopyableObject:
|
||||
__slots__ = ["value"]
|
||||
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
raise TypeError("fail deep copy ")
|
||||
|
||||
def __copy__(self):
|
||||
raise TypeError("fail shallow copy")
|
||||
|
||||
|
||||
def test_deepcopy_simple_dict():
|
||||
original = {"a": 1, "b": 2, "c": [3, 4, 5]}
|
||||
copy_obj = safe_deepcopy(original)
|
||||
assert copy_obj == original
|
||||
assert copy_obj is not original
|
||||
assert copy_obj["c"] is not original["c"]
|
||||
|
||||
|
||||
def test_deepcopy_simple_list():
|
||||
original = [1, 2, 3, [4, 5]]
|
||||
copy_obj = safe_deepcopy(original)
|
||||
assert copy_obj == original
|
||||
assert copy_obj is not original
|
||||
assert copy_obj[3] is not original[3]
|
||||
|
||||
|
||||
def test_deepcopy_with_tuple():
|
||||
original = (1, 2, [3, 4])
|
||||
copy_obj = safe_deepcopy(original)
|
||||
assert copy_obj == original
|
||||
assert copy_obj is not original
|
||||
assert copy_obj[2] is not original[2]
|
||||
|
||||
|
||||
def test_deepcopy_with_frozenset():
|
||||
original = frozenset([1, 2, 3, (4, 5)])
|
||||
copy_obj = safe_deepcopy(original)
|
||||
assert copy_obj == original
|
||||
assert copy_obj is not original
|
||||
|
||||
|
||||
def test_deepcopy_with_object():
|
||||
original = NormalObject(10)
|
||||
copy_obj = safe_deepcopy(original)
|
||||
assert copy_obj.value == original.value
|
||||
assert copy_obj is not original
|
||||
assert copy_obj.nested is not original.nested
|
||||
|
||||
|
||||
def test_deepcopy_with_custom_deepcopy_fallback():
|
||||
original = {"origin": NormalObject(10)}
|
||||
copy_obj = safe_deepcopy(original)
|
||||
assert copy_obj is not original
|
||||
assert copy_obj["origin"].value == original["origin"].value
|
||||
|
||||
|
||||
def test_shallow_copy_fallback():
|
||||
original = {"origin": NonDeepcopyable(10)}
|
||||
copy_obj = safe_deepcopy(original)
|
||||
assert copy_obj is not original
|
||||
assert copy_obj["origin"].value == original["origin"].value
|
||||
|
||||
|
||||
def test_circular_reference():
|
||||
original = []
|
||||
original.append(original)
|
||||
copy_obj = safe_deepcopy(original)
|
||||
assert copy_obj is not original
|
||||
assert copy_obj[0] is copy_obj
|
||||
|
||||
|
||||
|
||||
|
||||
def test_deepcopy_object_without_dict():
|
||||
original = {"origin": WithoutDict(10)}
|
||||
copy_obj = safe_deepcopy(original)
|
||||
assert copy_obj["origin"].value == original["origin"].value
|
||||
assert copy_obj is not original
|
||||
assert copy_obj["origin"] is original["origin"]
|
||||
assert (
|
||||
hasattr(copy_obj["origin"], "__dict__") is False
|
||||
) # Ensure __dict__ is not present
|
||||
|
||||
original = [WithoutDict(10)]
|
||||
copy_obj = safe_deepcopy(original)
|
||||
assert copy_obj[0].value == original[0].value
|
||||
assert copy_obj is not original
|
||||
assert copy_obj[0] is original[0]
|
||||
|
||||
original = (WithoutDict(10),)
|
||||
copy_obj = safe_deepcopy(original)
|
||||
assert copy_obj[0].value == original[0].value
|
||||
assert copy_obj is not original
|
||||
assert copy_obj[0] is original[0]
|
||||
|
||||
original_item = WithoutDict(10)
|
||||
original = set([original_item])
|
||||
copy_obj = safe_deepcopy(original)
|
||||
assert copy_obj is not original
|
||||
copy_obj_item = copy_obj.pop()
|
||||
assert copy_obj_item.value == original_item.value
|
||||
assert copy_obj_item is original_item
|
||||
|
||||
original_item = WithoutDict(10)
|
||||
original = frozenset([original_item])
|
||||
copy_obj = safe_deepcopy(original)
|
||||
assert copy_obj is not original
|
||||
copy_obj_item = list(copy_obj)[0]
|
||||
assert copy_obj_item.value == original_item.value
|
||||
assert copy_obj_item is original_item
|
||||
|
||||
def test_unhandled_type():
|
||||
with pytest.raises(DeepCopyError):
|
||||
original = {"origin": NonCopyableObject(10)}
|
||||
copy_obj = safe_deepcopy(original)
|
||||
|
||||
def test_client():
|
||||
llm_instance_config = {
|
||||
"model": "moonshot-v1-8k",
|
||||
"base_url": "https://api.moonshot.cn/v1",
|
||||
"moonshot_api_key": "xxx",
|
||||
}
|
||||
|
||||
from langchain_community.chat_models.moonshot import MoonshotChat
|
||||
|
||||
llm_model_instance = MoonshotChat(**llm_instance_config)
|
||||
copy_obj = safe_deepcopy(llm_model_instance)
|
||||
|
||||
assert copy_obj
|
||||
assert hasattr(copy_obj, 'callbacks')
|
||||
|
||||
def test_circular_reference_in_dict():
|
||||
original = {}
|
||||
original['self'] = original # Create a circular reference
|
||||
copy_obj = safe_deepcopy(original)
|
||||
|
||||
# Check that the copy is a different object
|
||||
assert copy_obj is not original
|
||||
# Check that the circular reference is maintained in the copy
|
||||
assert copy_obj['self'] is copy_obj
|
||||
|
||||
def test_with_pydantic():
|
||||
original = PydantObject(value=1)
|
||||
copy_obj = safe_deepcopy(original)
|
||||
assert copy_obj.value == original.value
|
||||
assert copy_obj is not original
|
||||
Loading…
Reference in New Issue
Block a user