fix: fix pydantic object copy

This commit is contained in:
smith peng 2024-09-01 16:40:08 +08:00
parent 71b22d4880
commit 553527a269
2 changed files with 21 additions and 22 deletions

View File

@ -10,8 +10,8 @@ def safe_deepcopy(obj: Any) -> Any:
"""
Attempts to create a deep copy of the object using `copy.deepcopy`
whenever possible. If that fails, it falls back to custom deep copy
logic or returns the original object.
logic. If that also fails, it raises a `DeepCopyError`.
Args:
obj (Any): The object to be copied, which can be of any type.
@ -26,13 +26,7 @@ def safe_deepcopy(obj: Any) -> Any:
try:
# Try to use copy.deepcopy first
if isinstance(obj,BaseModel):
# handle BaseModel because __fields_set__ need compatibility
copied_obj = obj.copy(deep=True)
else:
copied_obj = copy.deepcopy(obj)
return copied_obj
return copy.deepcopy(obj)
except (TypeError, AttributeError) as e:
# If deepcopy fails, handle specific types manually
@ -65,14 +59,17 @@ def safe_deepcopy(obj: Any) -> Any:
# Handle objects with attributes
elif hasattr(obj, "__dict__"):
new_obj = obj.__new__(obj.__class__)
for attr in obj.__dict__:
setattr(new_obj, attr, safe_deepcopy(getattr(obj, attr)))
return new_obj
# If an object cannot be deep copied, then the sub-properties of \
# the object will not be analyzed and shallow copy will be used directly.
try:
return copy.copy(obj)
except (TypeError, AttributeError):
raise DeepCopyError(f"Cannot deep copy the object of type {type(obj)}") from e
# Attempt shallow copy as a fallback
try:
return copy.copy(obj)
except (TypeError, AttributeError):
raise DeepCopyError(f"Cannot deep copy the object of type {type(obj)}") from e

View File

@ -4,14 +4,10 @@ import pytest
# Assuming the custom_deepcopy function is imported or defined above this line
from scrapegraphai.utils.copy import DeepCopyError, safe_deepcopy
from pydantic.v1 import BaseModel
from pydantic import BaseModel as BaseModelV2
class PydantObject(BaseModel):
value: int
class PydantObjectV2(BaseModelV2):
value: int
class NormalObject:
def __init__(self, value):
self.value = value
@ -162,16 +158,16 @@ def test_client():
llm_instance_config = {
"model": "moonshot-v1-8k",
"base_url": "https://api.moonshot.cn/v1",
"moonshot_api_key": "sk-OWo8hbSubp1QzOPyskOEwXQtZ867Ph0PZWCQdWrc3PH4o0lI",
"moonshot_api_key": "xxx",
}
from langchain_community.chat_models.moonshot import MoonshotChat
llm_model_instance = MoonshotChat(**llm_instance_config)
copy_obj = safe_deepcopy(llm_model_instance)
assert copy_obj
assert hasattr(copy_obj, 'callbacks')
def test_circular_reference_in_dict():
original = {}
@ -182,3 +178,9 @@ def test_circular_reference_in_dict():
assert copy_obj is not original
# Check that the circular reference is maintained in the copy
assert copy_obj['self'] is copy_obj
def test_with_pydantic():
original = PydantObject(value=1)
copy_obj = safe_deepcopy(original)
assert copy_obj.value == original.value
assert copy_obj is not original