diff --git a/scrapegraphai/utils/copy.py b/scrapegraphai/utils/copy.py index 4ccfcbf1..e1fdd37f 100644 --- a/scrapegraphai/utils/copy.py +++ b/scrapegraphai/utils/copy.py @@ -1,8 +1,9 @@ import copy from typing import Any, Dict, Optional +from pydantic.v1 import BaseModel -def safe_deepcopy(obj: Any, memo: Optional[Dict[int, Any]] = None) -> Any: +def safe_deepcopy(obj: Any) -> Any: """ Attempts to create a deep copy of the object using `copy.deepcopy` whenever possible. If that fails, it falls back to custom deep copy @@ -10,9 +11,6 @@ def safe_deepcopy(obj: Any, memo: Optional[Dict[int, Any]] = None) -> Any: Args: obj (Any): The object to be copied, which can be of any type. - memo (Optional[Dict[int, Any]]): A dictionary used to track objects - that have already been copied to handle circular references. - If None, a new dictionary is created. Returns: Any: A deep copy of the object if possible; otherwise, a shallow @@ -20,59 +18,56 @@ def safe_deepcopy(obj: Any, memo: Optional[Dict[int, Any]] = None) -> Any: object is returned. """ - if memo is None: - memo = {} - - if id(obj) in memo: - return memo[id(obj)] - try: + # Try to use copy.deepcopy first - return copy.deepcopy(obj, memo) - except (TypeError, AttributeError): + if isinstance(obj,BaseModel): + # handle BaseModel because __fields_set__ need compatibility + copied_obj = obj.copy(deep=True) + else: + copied_obj = copy.deepcopy(obj) + + return copied_obj + except (TypeError, AttributeError) as e: # If deepcopy fails, handle specific types manually # Handle dictionaries if isinstance(obj, dict): new_obj = {} - memo[id(obj)] = new_obj + for k, v in obj.items(): - new_obj[k] = safe_deepcopy(v, memo) + new_obj[k] = safe_deepcopy(v) return new_obj # Handle lists elif isinstance(obj, list): new_obj = [] - memo[id(obj)] = new_obj + for v in obj: - new_obj.append(safe_deepcopy(v, memo)) + new_obj.append(safe_deepcopy(v)) return new_obj # Handle tuples (immutable, but might contain mutable objects) elif isinstance(obj, tuple): - new_obj = tuple(safe_deepcopy(v, memo) for v in obj) - memo[id(obj)] = new_obj + new_obj = tuple(safe_deepcopy(v) for v in obj) + return new_obj # Handle frozensets (immutable, but might contain mutable objects) elif isinstance(obj, frozenset): - new_obj = frozenset(safe_deepcopy(v, memo) for v in obj) - memo[id(obj)] = new_obj + new_obj = frozenset(safe_deepcopy(v) for v in obj) return new_obj # Handle objects with attributes elif hasattr(obj, "__dict__"): new_obj = obj.__new__(obj.__class__) for attr in obj.__dict__: - setattr(new_obj, attr, safe_deepcopy(getattr(obj, attr), memo)) - memo[id(obj)] = new_obj + setattr(new_obj, attr, safe_deepcopy(getattr(obj, attr))) + return new_obj - + # Attempt shallow copy as a fallback try: return copy.copy(obj) except (TypeError, AttributeError): - pass - - # If all else fails, return the original object - return obj + raise TypeError(f"Failed to create a deep copy obj") from e diff --git a/tests/utils/copy_utils_test.py b/tests/utils/copy_utils_test.py index d5d523a8..8fb5a804 100644 --- a/tests/utils/copy_utils_test.py +++ b/tests/utils/copy_utils_test.py @@ -3,16 +3,20 @@ import pytest # Assuming the custom_deepcopy function is imported or defined above this line from scrapegraphai.utils.copy import safe_deepcopy +from pydantic.v1 import BaseModel +from pydantic import BaseModel as BaseModelV2 +class PydantObject(BaseModel): + value: int + +class PydantObjectV2(BaseModelV2): + value: int class NormalObject: def __init__(self, value): self.value = value self.nested = [1, 2, 3] - def __deepcopy__(self, memo): - raise TypeError("Forcing fallback") - class NonDeepcopyable: def __init__(self, value): @@ -109,11 +113,6 @@ def test_circular_reference(): assert copy_obj[0] is copy_obj -def test_memoization(): - original = {"a": 1, "b": 2} - memo = {} - copy_obj = safe_deepcopy(original, memo) - assert copy_obj is memo[id(original)] def test_deepcopy_object_without_dict(): @@ -154,17 +153,32 @@ def test_deepcopy_object_without_dict(): assert copy_obj_item.value == original_item.value assert copy_obj_item is original_item -def test_memo(): - obj = NormalObject(10) - original = {"origin": obj} - memo = {id(original):obj} - copy_obj = safe_deepcopy(original, memo) - assert copy_obj is memo[id(original)] - def test_unhandled_type(): - original = {"origin": NonCopyableObject(10)} + with pytest.raises(TypeError): + original = {"origin": NonCopyableObject(10)} + copy_obj = safe_deepcopy(original) + +def test_client(): + llm_instance_config = { + "model": "moonshot-v1-8k", + "base_url": "https://api.moonshot.cn/v1", + "api_key": "xxx", + } + + from langchain_community.chat_models.moonshot import MoonshotChat + + llm_model_instance = MoonshotChat(**llm_instance_config) + + copy_obj = safe_deepcopy(llm_model_instance) + assert copy_obj + + +def test_circular_reference_in_dict(): + original = {} + original['self'] = original # Create a circular reference copy_obj = safe_deepcopy(original) - assert copy_obj["origin"].value == original["origin"].value + + # Check that the copy is a different object assert copy_obj is not original - assert copy_obj["origin"] is original["origin"] - assert hasattr(copy_obj, "__dict__") is False # Ensure __dict__ is not present + # Check that the circular reference is maintained in the copy + assert copy_obj['self'] is copy_obj