feat:adjust uncopiable obj raise error and remove memo

This commit is contained in:
smith peng 2024-08-31 17:39:33 +08:00
parent cd07418474
commit 36818b1fb3
2 changed files with 55 additions and 46 deletions

View File

@ -1,8 +1,9 @@
import copy
from typing import Any, Dict, Optional
from pydantic.v1 import BaseModel
def safe_deepcopy(obj: Any, memo: Optional[Dict[int, Any]] = None) -> Any:
def safe_deepcopy(obj: Any) -> Any:
"""
Attempts to create a deep copy of the object using `copy.deepcopy`
whenever possible. If that fails, it falls back to custom deep copy
@ -10,9 +11,6 @@ def safe_deepcopy(obj: Any, memo: Optional[Dict[int, Any]] = None) -> Any:
Args:
obj (Any): The object to be copied, which can be of any type.
memo (Optional[Dict[int, Any]]): A dictionary used to track objects
that have already been copied to handle circular references.
If None, a new dictionary is created.
Returns:
Any: A deep copy of the object if possible; otherwise, a shallow
@ -20,59 +18,56 @@ def safe_deepcopy(obj: Any, memo: Optional[Dict[int, Any]] = None) -> Any:
object is returned.
"""
if memo is None:
memo = {}
if id(obj) in memo:
return memo[id(obj)]
try:
# Try to use copy.deepcopy first
return copy.deepcopy(obj, memo)
except (TypeError, AttributeError):
if isinstance(obj,BaseModel):
# handle BaseModel because __fields_set__ need compatibility
copied_obj = obj.copy(deep=True)
else:
copied_obj = copy.deepcopy(obj)
return copied_obj
except (TypeError, AttributeError) as e:
# If deepcopy fails, handle specific types manually
# Handle dictionaries
if isinstance(obj, dict):
new_obj = {}
memo[id(obj)] = new_obj
for k, v in obj.items():
new_obj[k] = safe_deepcopy(v, memo)
new_obj[k] = safe_deepcopy(v)
return new_obj
# Handle lists
elif isinstance(obj, list):
new_obj = []
memo[id(obj)] = new_obj
for v in obj:
new_obj.append(safe_deepcopy(v, memo))
new_obj.append(safe_deepcopy(v))
return new_obj
# Handle tuples (immutable, but might contain mutable objects)
elif isinstance(obj, tuple):
new_obj = tuple(safe_deepcopy(v, memo) for v in obj)
memo[id(obj)] = new_obj
new_obj = tuple(safe_deepcopy(v) for v in obj)
return new_obj
# Handle frozensets (immutable, but might contain mutable objects)
elif isinstance(obj, frozenset):
new_obj = frozenset(safe_deepcopy(v, memo) for v in obj)
memo[id(obj)] = new_obj
new_obj = frozenset(safe_deepcopy(v) for v in obj)
return new_obj
# Handle objects with attributes
elif hasattr(obj, "__dict__"):
new_obj = obj.__new__(obj.__class__)
for attr in obj.__dict__:
setattr(new_obj, attr, safe_deepcopy(getattr(obj, attr), memo))
memo[id(obj)] = new_obj
setattr(new_obj, attr, safe_deepcopy(getattr(obj, attr)))
return new_obj
# Attempt shallow copy as a fallback
try:
return copy.copy(obj)
except (TypeError, AttributeError):
pass
# If all else fails, return the original object
return obj
raise TypeError(f"Failed to create a deep copy obj") from e

View File

@ -3,16 +3,20 @@ import pytest
# Assuming the custom_deepcopy function is imported or defined above this line
from scrapegraphai.utils.copy import safe_deepcopy
from pydantic.v1 import BaseModel
from pydantic import BaseModel as BaseModelV2
class PydantObject(BaseModel):
value: int
class PydantObjectV2(BaseModelV2):
value: int
class NormalObject:
def __init__(self, value):
self.value = value
self.nested = [1, 2, 3]
def __deepcopy__(self, memo):
raise TypeError("Forcing fallback")
class NonDeepcopyable:
def __init__(self, value):
@ -109,11 +113,6 @@ def test_circular_reference():
assert copy_obj[0] is copy_obj
def test_memoization():
original = {"a": 1, "b": 2}
memo = {}
copy_obj = safe_deepcopy(original, memo)
assert copy_obj is memo[id(original)]
def test_deepcopy_object_without_dict():
@ -154,17 +153,32 @@ def test_deepcopy_object_without_dict():
assert copy_obj_item.value == original_item.value
assert copy_obj_item is original_item
def test_memo():
obj = NormalObject(10)
original = {"origin": obj}
memo = {id(original):obj}
copy_obj = safe_deepcopy(original, memo)
assert copy_obj is memo[id(original)]
def test_unhandled_type():
original = {"origin": NonCopyableObject(10)}
with pytest.raises(TypeError):
original = {"origin": NonCopyableObject(10)}
copy_obj = safe_deepcopy(original)
def test_client():
llm_instance_config = {
"model": "moonshot-v1-8k",
"base_url": "https://api.moonshot.cn/v1",
"api_key": "xxx",
}
from langchain_community.chat_models.moonshot import MoonshotChat
llm_model_instance = MoonshotChat(**llm_instance_config)
copy_obj = safe_deepcopy(llm_model_instance)
assert copy_obj
def test_circular_reference_in_dict():
original = {}
original['self'] = original # Create a circular reference
copy_obj = safe_deepcopy(original)
assert copy_obj["origin"].value == original["origin"].value
# Check that the copy is a different object
assert copy_obj is not original
assert copy_obj["origin"] is original["origin"]
assert hasattr(copy_obj, "__dict__") is False # Ensure __dict__ is not present
# Check that the circular reference is maintained in the copy
assert copy_obj['self'] is copy_obj