mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-28 21:01:55 +08:00
fix: fix pydantic object copy
This commit is contained in:
parent
71b22d4880
commit
553527a269
@ -10,8 +10,8 @@ def safe_deepcopy(obj: Any) -> Any:
|
||||
"""
|
||||
Attempts to create a deep copy of the object using `copy.deepcopy`
|
||||
whenever possible. If that fails, it falls back to custom deep copy
|
||||
logic or returns the original object.
|
||||
|
||||
logic. If that also fails, it raises a `DeepCopyError`.
|
||||
|
||||
Args:
|
||||
obj (Any): The object to be copied, which can be of any type.
|
||||
|
||||
@ -26,13 +26,7 @@ def safe_deepcopy(obj: Any) -> Any:
|
||||
try:
|
||||
|
||||
# Try to use copy.deepcopy first
|
||||
if isinstance(obj,BaseModel):
|
||||
# handle BaseModel because __fields_set__ need compatibility
|
||||
copied_obj = obj.copy(deep=True)
|
||||
else:
|
||||
copied_obj = copy.deepcopy(obj)
|
||||
|
||||
return copied_obj
|
||||
return copy.deepcopy(obj)
|
||||
except (TypeError, AttributeError) as e:
|
||||
# If deepcopy fails, handle specific types manually
|
||||
|
||||
@ -65,14 +59,17 @@ def safe_deepcopy(obj: Any) -> Any:
|
||||
|
||||
# Handle objects with attributes
|
||||
elif hasattr(obj, "__dict__"):
|
||||
new_obj = obj.__new__(obj.__class__)
|
||||
for attr in obj.__dict__:
|
||||
setattr(new_obj, attr, safe_deepcopy(getattr(obj, attr)))
|
||||
|
||||
return new_obj
|
||||
|
||||
# If an object cannot be deep copied, then the sub-properties of \
|
||||
# the object will not be analyzed and shallow copy will be used directly.
|
||||
try:
|
||||
return copy.copy(obj)
|
||||
except (TypeError, AttributeError):
|
||||
raise DeepCopyError(f"Cannot deep copy the object of type {type(obj)}") from e
|
||||
|
||||
|
||||
# Attempt shallow copy as a fallback
|
||||
try:
|
||||
return copy.copy(obj)
|
||||
except (TypeError, AttributeError):
|
||||
raise DeepCopyError(f"Cannot deep copy the object of type {type(obj)}") from e
|
||||
|
||||
|
||||
@ -4,14 +4,10 @@ import pytest
|
||||
# Assuming the custom_deepcopy function is imported or defined above this line
|
||||
from scrapegraphai.utils.copy import DeepCopyError, safe_deepcopy
|
||||
from pydantic.v1 import BaseModel
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
|
||||
class PydantObject(BaseModel):
|
||||
value: int
|
||||
|
||||
class PydantObjectV2(BaseModelV2):
|
||||
value: int
|
||||
|
||||
class NormalObject:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
@ -162,16 +158,16 @@ def test_client():
|
||||
llm_instance_config = {
|
||||
"model": "moonshot-v1-8k",
|
||||
"base_url": "https://api.moonshot.cn/v1",
|
||||
"moonshot_api_key": "sk-OWo8hbSubp1QzOPyskOEwXQtZ867Ph0PZWCQdWrc3PH4o0lI",
|
||||
"moonshot_api_key": "xxx",
|
||||
}
|
||||
|
||||
from langchain_community.chat_models.moonshot import MoonshotChat
|
||||
|
||||
llm_model_instance = MoonshotChat(**llm_instance_config)
|
||||
|
||||
copy_obj = safe_deepcopy(llm_model_instance)
|
||||
|
||||
assert copy_obj
|
||||
|
||||
assert hasattr(copy_obj, 'callbacks')
|
||||
|
||||
def test_circular_reference_in_dict():
|
||||
original = {}
|
||||
@ -182,3 +178,9 @@ def test_circular_reference_in_dict():
|
||||
assert copy_obj is not original
|
||||
# Check that the circular reference is maintained in the copy
|
||||
assert copy_obj['self'] is copy_obj
|
||||
|
||||
def test_with_pydantic():
|
||||
original = PydantObject(value=1)
|
||||
copy_obj = safe_deepcopy(original)
|
||||
assert copy_obj.value == original.value
|
||||
assert copy_obj is not original
|
||||
|
||||
Loading…
Reference in New Issue
Block a user