feat: add deepcopy error

This commit is contained in:
smith peng 2024-08-31 17:55:14 +08:00
parent 36818b1fb3
commit 71b22d4880
2 changed files with 9 additions and 4 deletions

View File

@ -2,6 +2,9 @@ import copy
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from pydantic.v1 import BaseModel from pydantic.v1 import BaseModel
class DeepCopyError(Exception):
"""Custom exception raised when an object cannot be deep-copied."""
pass
def safe_deepcopy(obj: Any) -> Any: def safe_deepcopy(obj: Any) -> Any:
""" """
@ -16,6 +19,8 @@ def safe_deepcopy(obj: Any) -> Any:
Any: A deep copy of the object if possible; otherwise, a shallow Any: A deep copy of the object if possible; otherwise, a shallow
copy if deep copying fails; if neither is possible, the original copy if deep copying fails; if neither is possible, the original
object is returned. object is returned.
Raises:
DeepCopyError: If the object cannot be deep-copied or shallow-copied.
""" """
try: try:
@ -70,4 +75,4 @@ def safe_deepcopy(obj: Any) -> Any:
try: try:
return copy.copy(obj) return copy.copy(obj)
except (TypeError, AttributeError): except (TypeError, AttributeError):
raise TypeError(f"Failed to create a deep copy obj") from e raise DeepCopyError(f"Cannot deep copy the object of type {type(obj)}") from e

View File

@ -2,7 +2,7 @@ import copy
import pytest import pytest
# Assuming the custom_deepcopy function is imported or defined above this line # Assuming the custom_deepcopy function is imported or defined above this line
from scrapegraphai.utils.copy import safe_deepcopy from scrapegraphai.utils.copy import DeepCopyError, safe_deepcopy
from pydantic.v1 import BaseModel from pydantic.v1 import BaseModel
from pydantic import BaseModel as BaseModelV2 from pydantic import BaseModel as BaseModelV2
@ -154,7 +154,7 @@ def test_deepcopy_object_without_dict():
assert copy_obj_item is original_item assert copy_obj_item is original_item
def test_unhandled_type(): def test_unhandled_type():
with pytest.raises(TypeError): with pytest.raises(DeepCopyError):
original = {"origin": NonCopyableObject(10)} original = {"origin": NonCopyableObject(10)}
copy_obj = safe_deepcopy(original) copy_obj = safe_deepcopy(original)
@ -162,7 +162,7 @@ def test_client():
llm_instance_config = { llm_instance_config = {
"model": "moonshot-v1-8k", "model": "moonshot-v1-8k",
"base_url": "https://api.moonshot.cn/v1", "base_url": "https://api.moonshot.cn/v1",
"api_key": "xxx", "moonshot_api_key": "sk-OWo8hbSubp1QzOPyskOEwXQtZ867Ph0PZWCQdWrc3PH4o0lI",
} }
from langchain_community.chat_models.moonshot import MoonshotChat from langchain_community.chat_models.moonshot import MoonshotChat