From 4f4d091b825038c0e8d8ff1848aca7d1e10e986e Mon Sep 17 00:00:00 2001 From: smith peng Date: Sat, 31 Aug 2024 11:30:20 +0800 Subject: [PATCH 1/6] feat:add deepcopy tool --- scrapegraphai/utils/copy.py | 78 +++++++++++++++ tests/utils/copy_utils_test.py | 170 +++++++++++++++++++++++++++++++++ 2 files changed, 248 insertions(+) create mode 100644 scrapegraphai/utils/copy.py create mode 100644 tests/utils/copy_utils_test.py diff --git a/scrapegraphai/utils/copy.py b/scrapegraphai/utils/copy.py new file mode 100644 index 00000000..4ccfcbf1 --- /dev/null +++ b/scrapegraphai/utils/copy.py @@ -0,0 +1,78 @@ +import copy +from typing import Any, Dict, Optional + + +def safe_deepcopy(obj: Any, memo: Optional[Dict[int, Any]] = None) -> Any: + """ + Attempts to create a deep copy of the object using `copy.deepcopy` + whenever possible. If that fails, it falls back to custom deep copy + logic or returns the original object. + + Args: + obj (Any): The object to be copied, which can be of any type. + memo (Optional[Dict[int, Any]]): A dictionary used to track objects + that have already been copied to handle circular references. + If None, a new dictionary is created. + + Returns: + Any: A deep copy of the object if possible; otherwise, a shallow + copy if deep copying fails; if neither is possible, the original + object is returned. + """ + + if memo is None: + memo = {} + + if id(obj) in memo: + return memo[id(obj)] + + try: + # Try to use copy.deepcopy first + return copy.deepcopy(obj, memo) + except (TypeError, AttributeError): + # If deepcopy fails, handle specific types manually + + # Handle dictionaries + if isinstance(obj, dict): + new_obj = {} + memo[id(obj)] = new_obj + for k, v in obj.items(): + new_obj[k] = safe_deepcopy(v, memo) + return new_obj + + # Handle lists + elif isinstance(obj, list): + new_obj = [] + memo[id(obj)] = new_obj + for v in obj: + new_obj.append(safe_deepcopy(v, memo)) + return new_obj + + # Handle tuples (immutable, but might contain mutable objects) + elif isinstance(obj, tuple): + new_obj = tuple(safe_deepcopy(v, memo) for v in obj) + memo[id(obj)] = new_obj + return new_obj + + # Handle frozensets (immutable, but might contain mutable objects) + elif isinstance(obj, frozenset): + new_obj = frozenset(safe_deepcopy(v, memo) for v in obj) + memo[id(obj)] = new_obj + return new_obj + + # Handle objects with attributes + elif hasattr(obj, "__dict__"): + new_obj = obj.__new__(obj.__class__) + for attr in obj.__dict__: + setattr(new_obj, attr, safe_deepcopy(getattr(obj, attr), memo)) + memo[id(obj)] = new_obj + return new_obj + + # Attempt shallow copy as a fallback + try: + return copy.copy(obj) + except (TypeError, AttributeError): + pass + + # If all else fails, return the original object + return obj diff --git a/tests/utils/copy_utils_test.py b/tests/utils/copy_utils_test.py new file mode 100644 index 00000000..d5d523a8 --- /dev/null +++ b/tests/utils/copy_utils_test.py @@ -0,0 +1,170 @@ +import copy +import pytest + +# Assuming the custom_deepcopy function is imported or defined above this line +from scrapegraphai.utils.copy import safe_deepcopy + + +class NormalObject: + def __init__(self, value): + self.value = value + self.nested = [1, 2, 3] + + def __deepcopy__(self, memo): + raise TypeError("Forcing fallback") + + +class NonDeepcopyable: + def __init__(self, value): + self.value = value + + def __deepcopy__(self, memo): + raise TypeError("Forcing shallow copy fallback") + + +class WithoutDict: + __slots__ = ["value"] + + def __init__(self, value): + self.value = value + + def __deepcopy__(self, memo): + raise TypeError("Forcing shallow copy fallback") + + def __copy__(self): + return self + + +class NonCopyableObject: + __slots__ = ["value"] + + def __init__(self, value): + self.value = value + + def __deepcopy__(self, memo): + raise TypeError("fail deep copy ") + + def __copy__(self): + raise TypeError("fail shallow copy") + + +def test_deepcopy_simple_dict(): + original = {"a": 1, "b": 2, "c": [3, 4, 5]} + copy_obj = safe_deepcopy(original) + assert copy_obj == original + assert copy_obj is not original + assert copy_obj["c"] is not original["c"] + + +def test_deepcopy_simple_list(): + original = [1, 2, 3, [4, 5]] + copy_obj = safe_deepcopy(original) + assert copy_obj == original + assert copy_obj is not original + assert copy_obj[3] is not original[3] + + +def test_deepcopy_with_tuple(): + original = (1, 2, [3, 4]) + copy_obj = safe_deepcopy(original) + assert copy_obj == original + assert copy_obj is not original + assert copy_obj[2] is not original[2] + + +def test_deepcopy_with_frozenset(): + original = frozenset([1, 2, 3, (4, 5)]) + copy_obj = safe_deepcopy(original) + assert copy_obj == original + assert copy_obj is not original + + +def test_deepcopy_with_object(): + original = NormalObject(10) + copy_obj = safe_deepcopy(original) + assert copy_obj.value == original.value + assert copy_obj is not original + assert copy_obj.nested is not original.nested + + +def test_deepcopy_with_custom_deepcopy_fallback(): + original = {"origin": NormalObject(10)} + copy_obj = safe_deepcopy(original) + assert copy_obj is not original + assert copy_obj["origin"].value == original["origin"].value + + +def test_shallow_copy_fallback(): + original = {"origin": NonDeepcopyable(10)} + copy_obj = safe_deepcopy(original) + assert copy_obj is not original + assert copy_obj["origin"].value == original["origin"].value + + +def test_circular_reference(): + original = [] + original.append(original) + copy_obj = safe_deepcopy(original) + assert copy_obj is not original + assert copy_obj[0] is copy_obj + + +def test_memoization(): + original = {"a": 1, "b": 2} + memo = {} + copy_obj = safe_deepcopy(original, memo) + assert copy_obj is memo[id(original)] + + +def test_deepcopy_object_without_dict(): + original = {"origin": WithoutDict(10)} + copy_obj = safe_deepcopy(original) + assert copy_obj["origin"].value == original["origin"].value + assert copy_obj is not original + assert copy_obj["origin"] is original["origin"] + assert ( + hasattr(copy_obj["origin"], "__dict__") is False + ) # Ensure __dict__ is not present + + original = [WithoutDict(10)] + copy_obj = safe_deepcopy(original) + assert copy_obj[0].value == original[0].value + assert copy_obj is not original + assert copy_obj[0] is original[0] + + original = (WithoutDict(10),) + copy_obj = safe_deepcopy(original) + assert copy_obj[0].value == original[0].value + assert copy_obj is not original + assert copy_obj[0] is original[0] + + original_item = WithoutDict(10) + original = set([original_item]) + copy_obj = safe_deepcopy(original) + assert copy_obj is not original + copy_obj_item = copy_obj.pop() + assert copy_obj_item.value == original_item.value + assert copy_obj_item is original_item + + original_item = WithoutDict(10) + original = frozenset([original_item]) + copy_obj = safe_deepcopy(original) + assert copy_obj is not original + copy_obj_item = list(copy_obj)[0] + assert copy_obj_item.value == original_item.value + assert copy_obj_item is original_item + +def test_memo(): + obj = NormalObject(10) + original = {"origin": obj} + memo = {id(original):obj} + copy_obj = safe_deepcopy(original, memo) + assert copy_obj is memo[id(original)] + +def test_unhandled_type(): + original = {"origin": NonCopyableObject(10)} + copy_obj = safe_deepcopy(original) + assert copy_obj["origin"].value == original["origin"].value + assert copy_obj is not original + assert copy_obj["origin"] is original["origin"] + assert hasattr(copy_obj, "__dict__") is False # Ensure __dict__ is not present From cd07418474112cecd53ab47866262f2f31294223 Mon Sep 17 00:00:00 2001 From: smith peng Date: Sat, 31 Aug 2024 12:42:08 +0800 Subject: [PATCH 2/6] fix: deepcopy fail for coping model_instance config --- scrapegraphai/graphs/csv_scraper_multi_graph.py | 9 ++++----- scrapegraphai/graphs/json_scraper_multi_graph.py | 9 ++++----- scrapegraphai/graphs/markdown_scraper_multi_graph.py | 7 ++----- scrapegraphai/graphs/omni_search_graph.py | 8 +++----- scrapegraphai/graphs/pdf_scraper_multi_graph.py | 8 +++----- scrapegraphai/graphs/script_creator_multi_graph.py | 7 ++----- scrapegraphai/graphs/search_graph.py | 8 +++----- scrapegraphai/graphs/smart_scraper_multi_graph.py | 8 +++----- scrapegraphai/graphs/xml_scraper_multi_graph.py | 8 +++----- 9 files changed, 27 insertions(+), 45 deletions(-) diff --git a/scrapegraphai/graphs/csv_scraper_multi_graph.py b/scrapegraphai/graphs/csv_scraper_multi_graph.py index 59e84783..67498475 100644 --- a/scrapegraphai/graphs/csv_scraper_multi_graph.py +++ b/scrapegraphai/graphs/csv_scraper_multi_graph.py @@ -2,9 +2,10 @@ CSVScraperMultiGraph Module """ -from copy import copy, deepcopy from typing import List, Optional from pydantic import BaseModel + + from .base_graph import BaseGraph from .abstract_graph import AbstractGraph from .csv_scraper_graph import CSVScraperGraph @@ -12,6 +13,7 @@ from ..nodes import ( GraphIteratorNode, MergeAnswersNode ) +from ..utils.copy import safe_deepcopy class CSVScraperMultiGraph(AbstractGraph): """ @@ -46,10 +48,7 @@ class CSVScraperMultiGraph(AbstractGraph): self.max_results = config.get("max_results", 3) - if all(isinstance(value, str) for value in config.values()): - self.copy_config = copy(config) - else: - self.copy_config = deepcopy(config) + self.copy_config = safe_deepcopy(config) super().__init__(prompt, config, source, schema) diff --git a/scrapegraphai/graphs/json_scraper_multi_graph.py b/scrapegraphai/graphs/json_scraper_multi_graph.py index 42d2232e..c72d8afd 100644 --- a/scrapegraphai/graphs/json_scraper_multi_graph.py +++ b/scrapegraphai/graphs/json_scraper_multi_graph.py @@ -2,9 +2,10 @@ JSONScraperMultiGraph Module """ -from copy import copy, deepcopy +from copy import deepcopy from typing import List, Optional from pydantic import BaseModel + from .base_graph import BaseGraph from .abstract_graph import AbstractGraph from .json_scraper_graph import JSONScraperGraph @@ -12,6 +13,7 @@ from ..nodes import ( GraphIteratorNode, MergeAnswersNode ) +from ..utils.copy import safe_deepcopy class JSONScraperMultiGraph(AbstractGraph): """ @@ -45,10 +47,7 @@ class JSONScraperMultiGraph(AbstractGraph): self.max_results = config.get("max_results", 3) - if all(isinstance(value, str) for value in config.values()): - self.copy_config = copy(config) - else: - self.copy_config = deepcopy(config) + self.copy_config = safe_deepcopy(config) self.copy_schema = deepcopy(schema) diff --git a/scrapegraphai/graphs/markdown_scraper_multi_graph.py b/scrapegraphai/graphs/markdown_scraper_multi_graph.py index 9796c11a..772eebe6 100644 --- a/scrapegraphai/graphs/markdown_scraper_multi_graph.py +++ b/scrapegraphai/graphs/markdown_scraper_multi_graph.py @@ -12,6 +12,7 @@ from ..nodes import ( GraphIteratorNode, MergeAnswersNode ) +from ..utils.copy import safe_deepcopy class MDScraperMultiGraph(AbstractGraph): """ @@ -42,11 +43,7 @@ class MDScraperMultiGraph(AbstractGraph): """ def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[BaseModel] = None): - if all(isinstance(value, str) for value in config.values()): - self.copy_config = copy(config) - else: - self.copy_config = deepcopy(config) - + self.copy_config = safe_deepcopy(config) self.copy_schema = deepcopy(schema) super().__init__(prompt, config, source, schema) diff --git a/scrapegraphai/graphs/omni_search_graph.py b/scrapegraphai/graphs/omni_search_graph.py index b6f6df59..c005dbac 100644 --- a/scrapegraphai/graphs/omni_search_graph.py +++ b/scrapegraphai/graphs/omni_search_graph.py @@ -2,7 +2,7 @@ OmniSearchGraph Module """ -from copy import copy, deepcopy +from copy import deepcopy from typing import Optional from pydantic import BaseModel @@ -15,6 +15,7 @@ from ..nodes import ( GraphIteratorNode, MergeAnswersNode ) +from ..utils.copy import safe_deepcopy class OmniSearchGraph(AbstractGraph): @@ -48,10 +49,7 @@ class OmniSearchGraph(AbstractGraph): self.max_results = config.get("max_results", 3) - if all(isinstance(value, str) for value in config.values()): - self.copy_config = copy(config) - else: - self.copy_config = deepcopy(config) + self.copy_config = safe_deepcopy(config) self.copy_schema = deepcopy(schema) diff --git a/scrapegraphai/graphs/pdf_scraper_multi_graph.py b/scrapegraphai/graphs/pdf_scraper_multi_graph.py index a7386267..06da6944 100644 --- a/scrapegraphai/graphs/pdf_scraper_multi_graph.py +++ b/scrapegraphai/graphs/pdf_scraper_multi_graph.py @@ -2,7 +2,7 @@ PdfScraperMultiGraph Module """ -from copy import copy, deepcopy +from copy import deepcopy from typing import List, Optional from pydantic import BaseModel from .base_graph import BaseGraph @@ -12,6 +12,7 @@ from ..nodes import ( GraphIteratorNode, MergeAnswersNode ) +from ..utils.copy import safe_deepcopy class PdfScraperMultiGraph(AbstractGraph): """ @@ -44,10 +45,7 @@ class PdfScraperMultiGraph(AbstractGraph): def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[BaseModel] = None): - if all(isinstance(value, str) for value in config.values()): - self.copy_config = copy(config) - else: - self.copy_config = deepcopy(config) + self.copy_config = safe_deepcopy(config) self.copy_schema = deepcopy(schema) diff --git a/scrapegraphai/graphs/script_creator_multi_graph.py b/scrapegraphai/graphs/script_creator_multi_graph.py index 969ba722..b2ea8465 100644 --- a/scrapegraphai/graphs/script_creator_multi_graph.py +++ b/scrapegraphai/graphs/script_creator_multi_graph.py @@ -2,7 +2,6 @@ ScriptCreatorMultiGraph Module """ -from copy import copy, deepcopy from typing import List, Optional from pydantic import BaseModel @@ -15,6 +14,7 @@ from ..nodes import ( GraphIteratorNode, MergeGeneratedScriptsNode ) +from ..utils.copy import safe_deepcopy class ScriptCreatorMultiGraph(AbstractGraph): """ @@ -47,10 +47,7 @@ class ScriptCreatorMultiGraph(AbstractGraph): self.max_results = config.get("max_results", 3) - if all(isinstance(value, str) for value in config.values()): - self.copy_config = copy(config) - else: - self.copy_config = deepcopy(config) + self.copy_config = safe_deepcopy(config) super().__init__(prompt, config, source, schema) diff --git a/scrapegraphai/graphs/search_graph.py b/scrapegraphai/graphs/search_graph.py index 080aaf19..d27e7186 100644 --- a/scrapegraphai/graphs/search_graph.py +++ b/scrapegraphai/graphs/search_graph.py @@ -2,7 +2,7 @@ SearchGraph Module """ -from copy import copy, deepcopy +from copy import deepcopy from typing import Optional, List from pydantic import BaseModel @@ -15,6 +15,7 @@ from ..nodes import ( GraphIteratorNode, MergeAnswersNode ) +from ..utils.copy import safe_deepcopy class SearchGraph(AbstractGraph): """ @@ -47,10 +48,7 @@ class SearchGraph(AbstractGraph): def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None): self.max_results = config.get("max_results", 3) - if all(isinstance(value, str) for value in config.values()): - self.copy_config = copy(config) - else: - self.copy_config = deepcopy(config) + self.copy_config = safe_deepcopy(config) self.copy_schema = deepcopy(schema) self.considered_urls = [] # New attribute to store URLs diff --git a/scrapegraphai/graphs/smart_scraper_multi_graph.py b/scrapegraphai/graphs/smart_scraper_multi_graph.py index 66d53851..82585cf0 100644 --- a/scrapegraphai/graphs/smart_scraper_multi_graph.py +++ b/scrapegraphai/graphs/smart_scraper_multi_graph.py @@ -2,7 +2,7 @@ SmartScraperMultiGraph Module """ -from copy import copy, deepcopy +from copy import deepcopy from typing import List, Optional from pydantic import BaseModel @@ -14,6 +14,7 @@ from ..nodes import ( GraphIteratorNode, MergeAnswersNode ) +from ..utils.copy import safe_deepcopy class SmartScraperMultiGraph(AbstractGraph): """ @@ -48,10 +49,7 @@ class SmartScraperMultiGraph(AbstractGraph): self.max_results = config.get("max_results", 3) - if all(isinstance(value, str) for value in config.values()): - self.copy_config = copy(config) - else: - self.copy_config = deepcopy(config) + self.copy_config = safe_deepcopy(config) self.copy_schema = deepcopy(schema) diff --git a/scrapegraphai/graphs/xml_scraper_multi_graph.py b/scrapegraphai/graphs/xml_scraper_multi_graph.py index 8050d50c..493d12ca 100644 --- a/scrapegraphai/graphs/xml_scraper_multi_graph.py +++ b/scrapegraphai/graphs/xml_scraper_multi_graph.py @@ -2,7 +2,7 @@ XMLScraperMultiGraph Module """ -from copy import copy, deepcopy +from copy import deepcopy from typing import List, Optional from pydantic import BaseModel @@ -14,6 +14,7 @@ from ..nodes import ( GraphIteratorNode, MergeAnswersNode ) +from ..utils.copy import safe_deepcopy class XMLScraperMultiGraph(AbstractGraph): """ @@ -46,10 +47,7 @@ class XMLScraperMultiGraph(AbstractGraph): def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[BaseModel] = None): - if all(isinstance(value, str) for value in config.values()): - self.copy_config = copy(config) - else: - self.copy_config = deepcopy(config) + self.copy_config = safe_deepcopy(config) self.copy_schema = deepcopy(schema) From 36818b1fb36e8253e8746ce4dd76dbf62076355f Mon Sep 17 00:00:00 2001 From: smith peng Date: Sat, 31 Aug 2024 17:39:33 +0800 Subject: [PATCH 3/6] feat:adjust uncopiable obj raise error and remove memo --- scrapegraphai/utils/copy.py | 49 ++++++++++++++------------------ tests/utils/copy_utils_test.py | 52 +++++++++++++++++++++------------- 2 files changed, 55 insertions(+), 46 deletions(-) diff --git a/scrapegraphai/utils/copy.py b/scrapegraphai/utils/copy.py index 4ccfcbf1..e1fdd37f 100644 --- a/scrapegraphai/utils/copy.py +++ b/scrapegraphai/utils/copy.py @@ -1,8 +1,9 @@ import copy from typing import Any, Dict, Optional +from pydantic.v1 import BaseModel -def safe_deepcopy(obj: Any, memo: Optional[Dict[int, Any]] = None) -> Any: +def safe_deepcopy(obj: Any) -> Any: """ Attempts to create a deep copy of the object using `copy.deepcopy` whenever possible. If that fails, it falls back to custom deep copy @@ -10,9 +11,6 @@ def safe_deepcopy(obj: Any, memo: Optional[Dict[int, Any]] = None) -> Any: Args: obj (Any): The object to be copied, which can be of any type. - memo (Optional[Dict[int, Any]]): A dictionary used to track objects - that have already been copied to handle circular references. - If None, a new dictionary is created. Returns: Any: A deep copy of the object if possible; otherwise, a shallow @@ -20,59 +18,56 @@ def safe_deepcopy(obj: Any, memo: Optional[Dict[int, Any]] = None) -> Any: object is returned. """ - if memo is None: - memo = {} - - if id(obj) in memo: - return memo[id(obj)] - try: + # Try to use copy.deepcopy first - return copy.deepcopy(obj, memo) - except (TypeError, AttributeError): + if isinstance(obj,BaseModel): + # handle BaseModel because __fields_set__ need compatibility + copied_obj = obj.copy(deep=True) + else: + copied_obj = copy.deepcopy(obj) + + return copied_obj + except (TypeError, AttributeError) as e: # If deepcopy fails, handle specific types manually # Handle dictionaries if isinstance(obj, dict): new_obj = {} - memo[id(obj)] = new_obj + for k, v in obj.items(): - new_obj[k] = safe_deepcopy(v, memo) + new_obj[k] = safe_deepcopy(v) return new_obj # Handle lists elif isinstance(obj, list): new_obj = [] - memo[id(obj)] = new_obj + for v in obj: - new_obj.append(safe_deepcopy(v, memo)) + new_obj.append(safe_deepcopy(v)) return new_obj # Handle tuples (immutable, but might contain mutable objects) elif isinstance(obj, tuple): - new_obj = tuple(safe_deepcopy(v, memo) for v in obj) - memo[id(obj)] = new_obj + new_obj = tuple(safe_deepcopy(v) for v in obj) + return new_obj # Handle frozensets (immutable, but might contain mutable objects) elif isinstance(obj, frozenset): - new_obj = frozenset(safe_deepcopy(v, memo) for v in obj) - memo[id(obj)] = new_obj + new_obj = frozenset(safe_deepcopy(v) for v in obj) return new_obj # Handle objects with attributes elif hasattr(obj, "__dict__"): new_obj = obj.__new__(obj.__class__) for attr in obj.__dict__: - setattr(new_obj, attr, safe_deepcopy(getattr(obj, attr), memo)) - memo[id(obj)] = new_obj + setattr(new_obj, attr, safe_deepcopy(getattr(obj, attr))) + return new_obj - + # Attempt shallow copy as a fallback try: return copy.copy(obj) except (TypeError, AttributeError): - pass - - # If all else fails, return the original object - return obj + raise TypeError(f"Failed to create a deep copy obj") from e diff --git a/tests/utils/copy_utils_test.py b/tests/utils/copy_utils_test.py index d5d523a8..8fb5a804 100644 --- a/tests/utils/copy_utils_test.py +++ b/tests/utils/copy_utils_test.py @@ -3,16 +3,20 @@ import pytest # Assuming the custom_deepcopy function is imported or defined above this line from scrapegraphai.utils.copy import safe_deepcopy +from pydantic.v1 import BaseModel +from pydantic import BaseModel as BaseModelV2 +class PydantObject(BaseModel): + value: int + +class PydantObjectV2(BaseModelV2): + value: int class NormalObject: def __init__(self, value): self.value = value self.nested = [1, 2, 3] - def __deepcopy__(self, memo): - raise TypeError("Forcing fallback") - class NonDeepcopyable: def __init__(self, value): @@ -109,11 +113,6 @@ def test_circular_reference(): assert copy_obj[0] is copy_obj -def test_memoization(): - original = {"a": 1, "b": 2} - memo = {} - copy_obj = safe_deepcopy(original, memo) - assert copy_obj is memo[id(original)] def test_deepcopy_object_without_dict(): @@ -154,17 +153,32 @@ def test_deepcopy_object_without_dict(): assert copy_obj_item.value == original_item.value assert copy_obj_item is original_item -def test_memo(): - obj = NormalObject(10) - original = {"origin": obj} - memo = {id(original):obj} - copy_obj = safe_deepcopy(original, memo) - assert copy_obj is memo[id(original)] - def test_unhandled_type(): - original = {"origin": NonCopyableObject(10)} + with pytest.raises(TypeError): + original = {"origin": NonCopyableObject(10)} + copy_obj = safe_deepcopy(original) + +def test_client(): + llm_instance_config = { + "model": "moonshot-v1-8k", + "base_url": "https://api.moonshot.cn/v1", + "api_key": "xxx", + } + + from langchain_community.chat_models.moonshot import MoonshotChat + + llm_model_instance = MoonshotChat(**llm_instance_config) + + copy_obj = safe_deepcopy(llm_model_instance) + assert copy_obj + + +def test_circular_reference_in_dict(): + original = {} + original['self'] = original # Create a circular reference copy_obj = safe_deepcopy(original) - assert copy_obj["origin"].value == original["origin"].value + + # Check that the copy is a different object assert copy_obj is not original - assert copy_obj["origin"] is original["origin"] - assert hasattr(copy_obj, "__dict__") is False # Ensure __dict__ is not present + # Check that the circular reference is maintained in the copy + assert copy_obj['self'] is copy_obj From 71b22d48804c462798109bb47ec792a5a3c70b6e Mon Sep 17 00:00:00 2001 From: smith peng Date: Sat, 31 Aug 2024 17:55:14 +0800 Subject: [PATCH 4/6] feat: add deepcopy error --- scrapegraphai/utils/copy.py | 7 ++++++- tests/utils/copy_utils_test.py | 6 +++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/scrapegraphai/utils/copy.py b/scrapegraphai/utils/copy.py index e1fdd37f..4f77f947 100644 --- a/scrapegraphai/utils/copy.py +++ b/scrapegraphai/utils/copy.py @@ -2,6 +2,9 @@ import copy from typing import Any, Dict, Optional from pydantic.v1 import BaseModel +class DeepCopyError(Exception): + """Custom exception raised when an object cannot be deep-copied.""" + pass def safe_deepcopy(obj: Any) -> Any: """ @@ -16,6 +19,8 @@ def safe_deepcopy(obj: Any) -> Any: Any: A deep copy of the object if possible; otherwise, a shallow copy if deep copying fails; if neither is possible, the original object is returned. + Raises: + DeepCopyError: If the object cannot be deep-copied or shallow-copied. """ try: @@ -70,4 +75,4 @@ def safe_deepcopy(obj: Any) -> Any: try: return copy.copy(obj) except (TypeError, AttributeError): - raise TypeError(f"Failed to create a deep copy obj") from e + raise DeepCopyError(f"Cannot deep copy the object of type {type(obj)}") from e diff --git a/tests/utils/copy_utils_test.py b/tests/utils/copy_utils_test.py index 8fb5a804..3cb1d5fb 100644 --- a/tests/utils/copy_utils_test.py +++ b/tests/utils/copy_utils_test.py @@ -2,7 +2,7 @@ import copy import pytest # Assuming the custom_deepcopy function is imported or defined above this line -from scrapegraphai.utils.copy import safe_deepcopy +from scrapegraphai.utils.copy import DeepCopyError, safe_deepcopy from pydantic.v1 import BaseModel from pydantic import BaseModel as BaseModelV2 @@ -154,7 +154,7 @@ def test_deepcopy_object_without_dict(): assert copy_obj_item is original_item def test_unhandled_type(): - with pytest.raises(TypeError): + with pytest.raises(DeepCopyError): original = {"origin": NonCopyableObject(10)} copy_obj = safe_deepcopy(original) @@ -162,7 +162,7 @@ def test_client(): llm_instance_config = { "model": "moonshot-v1-8k", "base_url": "https://api.moonshot.cn/v1", - "api_key": "xxx", + "moonshot_api_key": "sk-OWo8hbSubp1QzOPyskOEwXQtZ867Ph0PZWCQdWrc3PH4o0lI", } from langchain_community.chat_models.moonshot import MoonshotChat From 553527a269cdd70c0c174ad5c78cbf35c00b22c1 Mon Sep 17 00:00:00 2001 From: smith peng Date: Sun, 1 Sep 2024 16:40:08 +0800 Subject: [PATCH 5/6] fix: fix pydantic object copy --- scrapegraphai/utils/copy.py | 27 ++++++++++++--------------- tests/utils/copy_utils_test.py | 16 +++++++++------- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/scrapegraphai/utils/copy.py b/scrapegraphai/utils/copy.py index 4f77f947..2defbfa3 100644 --- a/scrapegraphai/utils/copy.py +++ b/scrapegraphai/utils/copy.py @@ -10,8 +10,8 @@ def safe_deepcopy(obj: Any) -> Any: """ Attempts to create a deep copy of the object using `copy.deepcopy` whenever possible. If that fails, it falls back to custom deep copy - logic or returns the original object. - + logic. If that also fails, it raises a `DeepCopyError`. + Args: obj (Any): The object to be copied, which can be of any type. @@ -26,13 +26,7 @@ def safe_deepcopy(obj: Any) -> Any: try: # Try to use copy.deepcopy first - if isinstance(obj,BaseModel): - # handle BaseModel because __fields_set__ need compatibility - copied_obj = obj.copy(deep=True) - else: - copied_obj = copy.deepcopy(obj) - - return copied_obj + return copy.deepcopy(obj) except (TypeError, AttributeError) as e: # If deepcopy fails, handle specific types manually @@ -65,14 +59,17 @@ def safe_deepcopy(obj: Any) -> Any: # Handle objects with attributes elif hasattr(obj, "__dict__"): - new_obj = obj.__new__(obj.__class__) - for attr in obj.__dict__: - setattr(new_obj, attr, safe_deepcopy(getattr(obj, attr))) - - return new_obj - + # If an object cannot be deep copied, then the sub-properties of \ + # the object will not be analyzed and shallow copy will be used directly. + try: + return copy.copy(obj) + except (TypeError, AttributeError): + raise DeepCopyError(f"Cannot deep copy the object of type {type(obj)}") from e + + # Attempt shallow copy as a fallback try: return copy.copy(obj) except (TypeError, AttributeError): raise DeepCopyError(f"Cannot deep copy the object of type {type(obj)}") from e + diff --git a/tests/utils/copy_utils_test.py b/tests/utils/copy_utils_test.py index 3cb1d5fb..90c85d34 100644 --- a/tests/utils/copy_utils_test.py +++ b/tests/utils/copy_utils_test.py @@ -4,14 +4,10 @@ import pytest # Assuming the custom_deepcopy function is imported or defined above this line from scrapegraphai.utils.copy import DeepCopyError, safe_deepcopy from pydantic.v1 import BaseModel -from pydantic import BaseModel as BaseModelV2 class PydantObject(BaseModel): value: int -class PydantObjectV2(BaseModelV2): - value: int - class NormalObject: def __init__(self, value): self.value = value @@ -162,16 +158,16 @@ def test_client(): llm_instance_config = { "model": "moonshot-v1-8k", "base_url": "https://api.moonshot.cn/v1", - "moonshot_api_key": "sk-OWo8hbSubp1QzOPyskOEwXQtZ867Ph0PZWCQdWrc3PH4o0lI", + "moonshot_api_key": "xxx", } from langchain_community.chat_models.moonshot import MoonshotChat llm_model_instance = MoonshotChat(**llm_instance_config) - copy_obj = safe_deepcopy(llm_model_instance) + assert copy_obj - + assert hasattr(copy_obj, 'callbacks') def test_circular_reference_in_dict(): original = {} @@ -182,3 +178,9 @@ def test_circular_reference_in_dict(): assert copy_obj is not original # Check that the circular reference is maintained in the copy assert copy_obj['self'] is copy_obj + +def test_with_pydantic(): + original = PydantObject(value=1) + copy_obj = safe_deepcopy(original) + assert copy_obj.value == original.value + assert copy_obj is not original From 8422463ca6d53da80610a7be214cfb76753c2e8f Mon Sep 17 00:00:00 2001 From: smith peng Date: Mon, 2 Sep 2024 14:28:39 +0800 Subject: [PATCH 6/6] feat:expose the search engine params to user --- scrapegraphai/graphs/omni_search_graph.py | 3 ++- scrapegraphai/graphs/search_graph.py | 3 ++- scrapegraphai/nodes/search_internet_node.py | 6 +++++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/scrapegraphai/graphs/omni_search_graph.py b/scrapegraphai/graphs/omni_search_graph.py index c005dbac..860685bd 100644 --- a/scrapegraphai/graphs/omni_search_graph.py +++ b/scrapegraphai/graphs/omni_search_graph.py @@ -83,7 +83,8 @@ class OmniSearchGraph(AbstractGraph): output=["urls"], node_config={ "llm_model": self.llm_model, - "max_results": self.max_results + "max_results": self.max_results, + "search_engine": self.copy_config.get("search_engine") } ) graph_iterator_node = GraphIteratorNode( diff --git a/scrapegraphai/graphs/search_graph.py b/scrapegraphai/graphs/search_graph.py index d27e7186..e34469be 100644 --- a/scrapegraphai/graphs/search_graph.py +++ b/scrapegraphai/graphs/search_graph.py @@ -76,7 +76,8 @@ class SearchGraph(AbstractGraph): output=["urls"], node_config={ "llm_model": self.llm_model, - "max_results": self.max_results + "max_results": self.max_results, + "search_engine": self.copy_config.get("search_engine") } ) graph_iterator_node = GraphIteratorNode( diff --git a/scrapegraphai/nodes/search_internet_node.py b/scrapegraphai/nodes/search_internet_node.py index df1b6277..14ce3207 100644 --- a/scrapegraphai/nodes/search_internet_node.py +++ b/scrapegraphai/nodes/search_internet_node.py @@ -41,7 +41,11 @@ class SearchInternetNode(BaseNode): self.verbose = ( False if node_config is None else node_config.get("verbose", False) ) - self.search_engine = node_config.get("search_engine", "google") + self.search_engine = ( + node_config["search_engine"] + if node_config.get("search_engine") + else "google" + ) self.max_results = node_config.get("max_results", 3) def execute(self, state: dict) -> dict: