Merge pull request #620 from goasleep/feature/export_search_engine

feat:expose the search engine params to user
This commit is contained in:
Marco Vinciguerra 2024-09-02 11:27:59 +02:00 committed by GitHub
commit 1bcc0bff0a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 297 additions and 48 deletions

View File

@ -2,9 +2,10 @@
CSVScraperMultiGraph Module
"""
from copy import copy, deepcopy
from typing import List, Optional
from pydantic import BaseModel
from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
from .csv_scraper_graph import CSVScraperGraph
@ -12,6 +13,7 @@ from ..nodes import (
GraphIteratorNode,
MergeAnswersNode
)
from ..utils.copy import safe_deepcopy
class CSVScraperMultiGraph(AbstractGraph):
"""
@ -46,10 +48,7 @@ class CSVScraperMultiGraph(AbstractGraph):
self.max_results = config.get("max_results", 3)
if all(isinstance(value, str) for value in config.values()):
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_config = safe_deepcopy(config)
super().__init__(prompt, config, source, schema)

View File

@ -2,9 +2,10 @@
JSONScraperMultiGraph Module
"""
from copy import copy, deepcopy
from copy import deepcopy
from typing import List, Optional
from pydantic import BaseModel
from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
from .json_scraper_graph import JSONScraperGraph
@ -12,6 +13,7 @@ from ..nodes import (
GraphIteratorNode,
MergeAnswersNode
)
from ..utils.copy import safe_deepcopy
class JSONScraperMultiGraph(AbstractGraph):
"""
@ -45,10 +47,7 @@ class JSONScraperMultiGraph(AbstractGraph):
self.max_results = config.get("max_results", 3)
if all(isinstance(value, str) for value in config.values()):
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_config = safe_deepcopy(config)
self.copy_schema = deepcopy(schema)

View File

@ -12,6 +12,7 @@ from ..nodes import (
GraphIteratorNode,
MergeAnswersNode
)
from ..utils.copy import safe_deepcopy
class MDScraperMultiGraph(AbstractGraph):
"""
@ -42,11 +43,7 @@ class MDScraperMultiGraph(AbstractGraph):
"""
def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[BaseModel] = None):
if all(isinstance(value, str) for value in config.values()):
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_config = safe_deepcopy(config)
self.copy_schema = deepcopy(schema)
super().__init__(prompt, config, source, schema)

View File

@ -2,7 +2,7 @@
OmniSearchGraph Module
"""
from copy import copy, deepcopy
from copy import deepcopy
from typing import Optional
from pydantic import BaseModel
@ -15,6 +15,7 @@ from ..nodes import (
GraphIteratorNode,
MergeAnswersNode
)
from ..utils.copy import safe_deepcopy
class OmniSearchGraph(AbstractGraph):
@ -48,10 +49,7 @@ class OmniSearchGraph(AbstractGraph):
self.max_results = config.get("max_results", 3)
if all(isinstance(value, str) for value in config.values()):
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_config = safe_deepcopy(config)
self.copy_schema = deepcopy(schema)
@ -85,7 +83,8 @@ class OmniSearchGraph(AbstractGraph):
output=["urls"],
node_config={
"llm_model": self.llm_model,
"max_results": self.max_results
"max_results": self.max_results,
"search_engine": self.copy_config.get("search_engine")
}
)
graph_iterator_node = GraphIteratorNode(

View File

@ -2,7 +2,7 @@
PdfScraperMultiGraph Module
"""
from copy import copy, deepcopy
from copy import deepcopy
from typing import List, Optional
from pydantic import BaseModel
from .base_graph import BaseGraph
@ -12,6 +12,7 @@ from ..nodes import (
GraphIteratorNode,
MergeAnswersNode
)
from ..utils.copy import safe_deepcopy
class PdfScraperMultiGraph(AbstractGraph):
"""
@ -44,10 +45,7 @@ class PdfScraperMultiGraph(AbstractGraph):
def __init__(self, prompt: str, source: List[str],
config: dict, schema: Optional[BaseModel] = None):
if all(isinstance(value, str) for value in config.values()):
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_config = safe_deepcopy(config)
self.copy_schema = deepcopy(schema)

View File

@ -2,7 +2,6 @@
ScriptCreatorMultiGraph Module
"""
from copy import copy, deepcopy
from typing import List, Optional
from pydantic import BaseModel
@ -15,6 +14,7 @@ from ..nodes import (
GraphIteratorNode,
MergeGeneratedScriptsNode
)
from ..utils.copy import safe_deepcopy
class ScriptCreatorMultiGraph(AbstractGraph):
"""
@ -47,10 +47,7 @@ class ScriptCreatorMultiGraph(AbstractGraph):
self.max_results = config.get("max_results", 3)
if all(isinstance(value, str) for value in config.values()):
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_config = safe_deepcopy(config)
super().__init__(prompt, config, source, schema)

View File

@ -2,7 +2,7 @@
SearchGraph Module
"""
from copy import copy, deepcopy
from copy import deepcopy
from typing import Optional, List
from pydantic import BaseModel
@ -15,6 +15,7 @@ from ..nodes import (
GraphIteratorNode,
MergeAnswersNode
)
from ..utils.copy import safe_deepcopy
class SearchGraph(AbstractGraph):
"""
@ -47,10 +48,7 @@ class SearchGraph(AbstractGraph):
def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None):
self.max_results = config.get("max_results", 3)
if all(isinstance(value, str) for value in config.values()):
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_config = safe_deepcopy(config)
self.copy_schema = deepcopy(schema)
self.considered_urls = [] # New attribute to store URLs
@ -78,7 +76,8 @@ class SearchGraph(AbstractGraph):
output=["urls"],
node_config={
"llm_model": self.llm_model,
"max_results": self.max_results
"max_results": self.max_results,
"search_engine": self.copy_config.get("search_engine")
}
)
graph_iterator_node = GraphIteratorNode(

View File

@ -2,7 +2,7 @@
SmartScraperMultiGraph Module
"""
from copy import copy, deepcopy
from copy import deepcopy
from typing import List, Optional
from pydantic import BaseModel
@ -14,6 +14,7 @@ from ..nodes import (
GraphIteratorNode,
MergeAnswersNode
)
from ..utils.copy import safe_deepcopy
class SmartScraperMultiGraph(AbstractGraph):
"""
@ -48,10 +49,7 @@ class SmartScraperMultiGraph(AbstractGraph):
self.max_results = config.get("max_results", 3)
if all(isinstance(value, str) for value in config.values()):
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_config = safe_deepcopy(config)
self.copy_schema = deepcopy(schema)

View File

@ -2,7 +2,7 @@
XMLScraperMultiGraph Module
"""
from copy import copy, deepcopy
from copy import deepcopy
from typing import List, Optional
from pydantic import BaseModel
@ -14,6 +14,7 @@ from ..nodes import (
GraphIteratorNode,
MergeAnswersNode
)
from ..utils.copy import safe_deepcopy
class XMLScraperMultiGraph(AbstractGraph):
"""
@ -46,10 +47,7 @@ class XMLScraperMultiGraph(AbstractGraph):
def __init__(self, prompt: str, source: List[str],
config: dict, schema: Optional[BaseModel] = None):
if all(isinstance(value, str) for value in config.values()):
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_config = safe_deepcopy(config)
self.copy_schema = deepcopy(schema)

View File

@ -41,7 +41,11 @@ class SearchInternetNode(BaseNode):
self.verbose = (
False if node_config is None else node_config.get("verbose", False)
)
self.search_engine = node_config.get("search_engine", "google")
self.search_engine = (
node_config["search_engine"]
if node_config.get("search_engine")
else "google"
)
self.max_results = node_config.get("max_results", 3)
def execute(self, state: dict) -> dict:

View File

@ -0,0 +1,75 @@
import copy
from typing import Any, Dict, Optional
from pydantic.v1 import BaseModel
class DeepCopyError(Exception):
    """Signals that an object could be neither deep-copied nor shallow-copied."""
def safe_deepcopy(obj: Any) -> Any:
    """
    Create the deepest copy of `obj` that can be obtained.

    `copy.deepcopy` is tried first. If it raises `TypeError` or
    `AttributeError`, dicts, lists, tuples and frozensets are rebuilt
    element by element (each element goes through the same procedure)
    and every other object, sets included, is shallow-copied with
    `copy.copy`. Dict keys are reused as they are; only values are copied.

    Dicts and lists that are rebuilt manually are remembered for the
    duration of the call, so a self-referencing container that also holds
    a non-deep-copyable value is copied once instead of recursing forever.

    Args:
        obj (Any): The object to be copied, which can be of any type.

    Returns:
        Any: A deep copy of the object when possible; otherwise a copy in
        which the parts that cannot be deep-copied are shallow copies.

    Raises:
        DeepCopyError: If the object, or one of its elements, can be
        neither deep-copied nor shallow-copied.
    """
    return _safe_deepcopy_impl(obj, {})


def _safe_deepcopy_impl(obj: Any, rebuilt: Dict[int, Any]) -> Any:
    """Copy `obj`, reusing containers already rebuilt in this call (keyed by `id`)."""
    # A container we are already rebuilding (or have rebuilt) is reused,
    # which is what breaks reference cycles in the manual fallback.
    if id(obj) in rebuilt:
        return rebuilt[id(obj)]

    try:
        # Try to use copy.deepcopy first
        return copy.deepcopy(obj)
    except (TypeError, AttributeError) as e:
        # Mutable containers are registered *before* their elements are
        # visited so that an element pointing back at them finds the copy.
        if isinstance(obj, dict):
            new_dict: dict = {}
            rebuilt[id(obj)] = new_dict
            for k, v in obj.items():
                new_dict[k] = _safe_deepcopy_impl(v, rebuilt)
            return new_dict

        if isinstance(obj, list):
            new_list: list = []
            rebuilt[id(obj)] = new_list
            for v in obj:
                new_list.append(_safe_deepcopy_impl(v, rebuilt))
            return new_list

        # Immutable containers may still hold mutable objects.
        if isinstance(obj, tuple):
            return tuple(_safe_deepcopy_impl(v, rebuilt) for v in obj)

        if isinstance(obj, frozenset):
            return frozenset(_safe_deepcopy_impl(v, rebuilt) for v in obj)

        # Anything else (with or without a __dict__) is not analyzed
        # further: a shallow copy is used directly.
        try:
            return copy.copy(obj)
        except (TypeError, AttributeError):
            raise DeepCopyError(f"Cannot deep copy the object of type {type(obj)}") from e

View File

@ -0,0 +1,186 @@
import copy
import pytest
# safe_deepcopy and DeepCopyError are the helpers under test
from scrapegraphai.utils.copy import DeepCopyError, safe_deepcopy
from pydantic.v1 import BaseModel
class PydantObject(BaseModel):
    """Minimal pydantic (v1 API) model with a single integer field."""

    value: int
class NormalObject:
    """Ordinary object with a scalar attribute and a nested list."""

    def __init__(self, value):
        # `nested` gives the copy tests a mutable attribute to compare by identity.
        self.value, self.nested = value, [1, 2, 3]
class NonDeepcopyable:
    """Object whose deep copy always fails while a shallow copy still works."""

    def __init__(self, value):
        self.value = value

    def __deepcopy__(self, memo):
        # Always refuse, so callers have to fall back to a shallow copy.
        raise TypeError("Forcing shallow copy fallback")
class WithoutDict:
    """Slotted object (no `__dict__`): deep copy fails, shallow copy returns itself."""

    __slots__ = ["value"]

    def __init__(self, value):
        self.value = value

    def __copy__(self):
        # The "copy" is the very same instance.
        return self

    def __deepcopy__(self, memo):
        raise TypeError("Forcing shallow copy fallback")
class NonCopyableObject:
    """Slotted object that rejects both shallow and deep copying."""

    __slots__ = ["value"]

    def __init__(self, value):
        self.value = value

    def __copy__(self):
        raise TypeError("fail shallow copy")

    def __deepcopy__(self, memo):
        raise TypeError("fail deep copy ")
def test_deepcopy_simple_dict():
    """A plain dict is duplicated together with its nested list."""
    source = {"a": 1, "b": 2, "c": [3, 4, 5]}
    duplicate = safe_deepcopy(source)

    assert duplicate is not source
    assert duplicate == source
    # The nested list must be a new object, not a shared reference.
    assert duplicate["c"] is not source["c"]
def test_deepcopy_simple_list():
    """A list is duplicated together with its nested list."""
    source = [1, 2, 3, [4, 5]]
    duplicate = safe_deepcopy(source)

    assert duplicate is not source
    assert duplicate == source
    # The inner list at index 3 must also be a new object.
    assert duplicate[3] is not source[3]
def test_deepcopy_with_tuple():
    """A tuple holding a list yields a new tuple with a new inner list."""
    source = (1, 2, [3, 4])
    duplicate = safe_deepcopy(source)

    assert duplicate is not source
    assert duplicate == source
    assert duplicate[2] is not source[2]
def test_deepcopy_with_frozenset():
    """A frozenset is copied into an equal but distinct frozenset."""
    source = frozenset({1, 2, 3, (4, 5)})
    duplicate = safe_deepcopy(source)

    assert duplicate is not source
    assert duplicate == source
def test_deepcopy_with_object():
    """A regular object is copied along with its nested list attribute."""
    source = NormalObject(10)
    duplicate = safe_deepcopy(source)

    assert duplicate is not source
    assert duplicate.value == source.value
    assert duplicate.nested is not source.nested
def test_deepcopy_with_custom_deepcopy_fallback():
    """A dict holding a NormalObject is copied and the object's value is kept."""
    source = {"origin": NormalObject(10)}
    duplicate = safe_deepcopy(source)

    assert duplicate is not source
    assert duplicate["origin"].value == source["origin"].value
def test_shallow_copy_fallback():
    """A dict value whose `__deepcopy__` raises is still copied, keeping its value."""
    source = {"origin": NonDeepcopyable(10)}
    duplicate = safe_deepcopy(source)

    assert duplicate is not source
    assert duplicate["origin"].value == source["origin"].value
def test_circular_reference():
    """A list containing itself is copied with the cycle pointing at the copy."""
    source = []
    source.append(source)  # the list is its own first element

    duplicate = safe_deepcopy(source)

    assert duplicate is not source
    assert duplicate[0] is duplicate
def test_deepcopy_object_without_dict():
    """
    Containers holding a `WithoutDict` get a new container but the same element.

    `WithoutDict.__deepcopy__` raises and `WithoutDict.__copy__` returns the
    instance itself, so every copied container is expected to hold the
    identical element object.
    """
    # dict
    source = {"origin": WithoutDict(10)}
    duplicate = safe_deepcopy(source)
    assert duplicate["origin"].value == source["origin"].value
    assert duplicate is not source
    assert duplicate["origin"] is source["origin"]
    # Ensure __dict__ is not present
    assert not hasattr(duplicate["origin"], "__dict__")

    # list
    source = [WithoutDict(10)]
    duplicate = safe_deepcopy(source)
    assert duplicate[0].value == source[0].value
    assert duplicate is not source
    assert duplicate[0] is source[0]

    # tuple
    source = (WithoutDict(10),)
    duplicate = safe_deepcopy(source)
    assert duplicate[0].value == source[0].value
    assert duplicate is not source
    assert duplicate[0] is source[0]

    # set
    element = WithoutDict(10)
    source = {element}
    duplicate = safe_deepcopy(source)
    assert duplicate is not source
    copied_element = duplicate.pop()
    assert copied_element.value == element.value
    assert copied_element is element

    # frozenset
    element = WithoutDict(10)
    source = frozenset([element])
    duplicate = safe_deepcopy(source)
    assert duplicate is not source
    copied_element = next(iter(duplicate))
    assert copied_element.value == element.value
    assert copied_element is element
def test_unhandled_type():
    """An object that refuses both deep and shallow copy raises DeepCopyError."""
    original = {"origin": NonCopyableObject(10)}
    # Only the call under test sits inside the raises block, so a failure
    # while building the fixture cannot be mistaken for the expected error.
    with pytest.raises(DeepCopyError):
        safe_deepcopy(original)
def test_client():
    """Copy a MoonshotChat client instance and check the copy exposes `callbacks`."""
    # Skip (rather than error) when the LangChain community package is not installed.
    moonshot = pytest.importorskip("langchain_community.chat_models.moonshot")

    llm_instance_config = {
        "model": "moonshot-v1-8k",
        "base_url": "https://api.moonshot.cn/v1",
        "moonshot_api_key": "xxx",
    }
    llm_model_instance = moonshot.MoonshotChat(**llm_instance_config)

    copy_obj = safe_deepcopy(llm_model_instance)

    assert copy_obj
    assert hasattr(copy_obj, 'callbacks')
def test_circular_reference_in_dict():
    """A dict containing itself is copied with the cycle pointing at the copy."""
    source = {}
    source['self'] = source  # the dict is its own value

    duplicate = safe_deepcopy(source)

    # The copy is a different object...
    assert duplicate is not source
    # ...and its self-reference targets the copy, not the source.
    assert duplicate['self'] is duplicate
def test_with_pydantic():
    """A pydantic model instance is copied into a distinct object with the same value."""
    source = PydantObject(value=1)
    duplicate = safe_deepcopy(source)

    assert duplicate is not source
    assert duplicate.value == source.value