mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-01 21:00:48 +08:00
feat: raise an error for uncopiable objects and remove the memo parameter
This commit is contained in:
parent
cd07418474
commit
36818b1fb3
@ -1,8 +1,9 @@
|
||||
import copy
|
||||
from typing import Any, Dict, Optional
|
||||
from pydantic.v1 import BaseModel
|
||||
|
||||
|
||||
def safe_deepcopy(obj: Any, memo: Optional[Dict[int, Any]] = None) -> Any:
|
||||
def safe_deepcopy(obj: Any) -> Any:
|
||||
"""
|
||||
Attempts to create a deep copy of the object using `copy.deepcopy`
|
||||
whenever possible. If that fails, it falls back to custom deep copy
|
||||
@ -10,9 +11,6 @@ def safe_deepcopy(obj: Any, memo: Optional[Dict[int, Any]] = None) -> Any:
|
||||
|
||||
Args:
|
||||
obj (Any): The object to be copied, which can be of any type.
|
||||
memo (Optional[Dict[int, Any]]): A dictionary used to track objects
|
||||
that have already been copied to handle circular references.
|
||||
If None, a new dictionary is created.
|
||||
|
||||
Returns:
|
||||
Any: A deep copy of the object if possible; otherwise, a shallow
|
||||
@ -20,59 +18,56 @@ def safe_deepcopy(obj: Any, memo: Optional[Dict[int, Any]] = None) -> Any:
|
||||
object is returned.
|
||||
"""
|
||||
|
||||
if memo is None:
|
||||
memo = {}
|
||||
|
||||
if id(obj) in memo:
|
||||
return memo[id(obj)]
|
||||
|
||||
try:
|
||||
|
||||
# Try to use copy.deepcopy first
|
||||
return copy.deepcopy(obj, memo)
|
||||
except (TypeError, AttributeError):
|
||||
if isinstance(obj,BaseModel):
|
||||
# handle BaseModel because __fields_set__ need compatibility
|
||||
copied_obj = obj.copy(deep=True)
|
||||
else:
|
||||
copied_obj = copy.deepcopy(obj)
|
||||
|
||||
return copied_obj
|
||||
except (TypeError, AttributeError) as e:
|
||||
# If deepcopy fails, handle specific types manually
|
||||
|
||||
# Handle dictionaries
|
||||
if isinstance(obj, dict):
|
||||
new_obj = {}
|
||||
memo[id(obj)] = new_obj
|
||||
|
||||
for k, v in obj.items():
|
||||
new_obj[k] = safe_deepcopy(v, memo)
|
||||
new_obj[k] = safe_deepcopy(v)
|
||||
return new_obj
|
||||
|
||||
# Handle lists
|
||||
elif isinstance(obj, list):
|
||||
new_obj = []
|
||||
memo[id(obj)] = new_obj
|
||||
|
||||
for v in obj:
|
||||
new_obj.append(safe_deepcopy(v, memo))
|
||||
new_obj.append(safe_deepcopy(v))
|
||||
return new_obj
|
||||
|
||||
# Handle tuples (immutable, but might contain mutable objects)
|
||||
elif isinstance(obj, tuple):
|
||||
new_obj = tuple(safe_deepcopy(v, memo) for v in obj)
|
||||
memo[id(obj)] = new_obj
|
||||
new_obj = tuple(safe_deepcopy(v) for v in obj)
|
||||
|
||||
return new_obj
|
||||
|
||||
# Handle frozensets (immutable, but might contain mutable objects)
|
||||
elif isinstance(obj, frozenset):
|
||||
new_obj = frozenset(safe_deepcopy(v, memo) for v in obj)
|
||||
memo[id(obj)] = new_obj
|
||||
new_obj = frozenset(safe_deepcopy(v) for v in obj)
|
||||
return new_obj
|
||||
|
||||
# Handle objects with attributes
|
||||
elif hasattr(obj, "__dict__"):
|
||||
new_obj = obj.__new__(obj.__class__)
|
||||
for attr in obj.__dict__:
|
||||
setattr(new_obj, attr, safe_deepcopy(getattr(obj, attr), memo))
|
||||
memo[id(obj)] = new_obj
|
||||
setattr(new_obj, attr, safe_deepcopy(getattr(obj, attr)))
|
||||
|
||||
return new_obj
|
||||
|
||||
|
||||
# Attempt shallow copy as a fallback
|
||||
try:
|
||||
return copy.copy(obj)
|
||||
except (TypeError, AttributeError):
|
||||
pass
|
||||
|
||||
# If all else fails, return the original object
|
||||
return obj
|
||||
raise TypeError(f"Failed to create a deep copy obj") from e
|
||||
|
||||
@ -3,16 +3,20 @@ import pytest
|
||||
|
||||
# safe_deepcopy, the function under test, is imported below
|
||||
from scrapegraphai.utils.copy import safe_deepcopy
|
||||
from pydantic.v1 import BaseModel
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
|
||||
class PydantObject(BaseModel):
|
||||
value: int
|
||||
|
||||
class PydantObjectV2(BaseModelV2):
|
||||
value: int
|
||||
|
||||
class NormalObject:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
self.nested = [1, 2, 3]
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
raise TypeError("Forcing fallback")
|
||||
|
||||
|
||||
class NonDeepcopyable:
|
||||
def __init__(self, value):
|
||||
@ -109,11 +113,6 @@ def test_circular_reference():
|
||||
assert copy_obj[0] is copy_obj
|
||||
|
||||
|
||||
def test_memoization():
|
||||
original = {"a": 1, "b": 2}
|
||||
memo = {}
|
||||
copy_obj = safe_deepcopy(original, memo)
|
||||
assert copy_obj is memo[id(original)]
|
||||
|
||||
|
||||
def test_deepcopy_object_without_dict():
|
||||
@ -154,17 +153,32 @@ def test_deepcopy_object_without_dict():
|
||||
assert copy_obj_item.value == original_item.value
|
||||
assert copy_obj_item is original_item
|
||||
|
||||
def test_memo():
|
||||
obj = NormalObject(10)
|
||||
original = {"origin": obj}
|
||||
memo = {id(original):obj}
|
||||
copy_obj = safe_deepcopy(original, memo)
|
||||
assert copy_obj is memo[id(original)]
|
||||
|
||||
def test_unhandled_type():
|
||||
original = {"origin": NonCopyableObject(10)}
|
||||
with pytest.raises(TypeError):
|
||||
original = {"origin": NonCopyableObject(10)}
|
||||
copy_obj = safe_deepcopy(original)
|
||||
|
||||
def test_client():
|
||||
llm_instance_config = {
|
||||
"model": "moonshot-v1-8k",
|
||||
"base_url": "https://api.moonshot.cn/v1",
|
||||
"api_key": "xxx",
|
||||
}
|
||||
|
||||
from langchain_community.chat_models.moonshot import MoonshotChat
|
||||
|
||||
llm_model_instance = MoonshotChat(**llm_instance_config)
|
||||
|
||||
copy_obj = safe_deepcopy(llm_model_instance)
|
||||
assert copy_obj
|
||||
|
||||
|
||||
def test_circular_reference_in_dict():
|
||||
original = {}
|
||||
original['self'] = original # Create a circular reference
|
||||
copy_obj = safe_deepcopy(original)
|
||||
assert copy_obj["origin"].value == original["origin"].value
|
||||
|
||||
# Check that the copy is a different object
|
||||
assert copy_obj is not original
|
||||
assert copy_obj["origin"] is original["origin"]
|
||||
assert hasattr(copy_obj, "__dict__") is False # Ensure __dict__ is not present
|
||||
# Check that the circular reference is maintained in the copy
|
||||
assert copy_obj['self'] is copy_obj
|
||||
|
||||
Loading…
Reference in New Issue
Block a user