fix: deepcopy fail for coping model_instance config

This commit is contained in:
smith peng 2024-08-31 12:42:08 +08:00
parent 4f4d091b82
commit cd07418474
9 changed files with 27 additions and 45 deletions

View File

@ -2,9 +2,10 @@
CSVScraperMultiGraph Module
"""
from copy import copy, deepcopy
from typing import List, Optional
from pydantic import BaseModel
from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
from .csv_scraper_graph import CSVScraperGraph
@ -12,6 +13,7 @@ from ..nodes import (
GraphIteratorNode,
MergeAnswersNode
)
from ..utils.copy import safe_deepcopy
class CSVScraperMultiGraph(AbstractGraph):
"""
@ -46,10 +48,7 @@ class CSVScraperMultiGraph(AbstractGraph):
self.max_results = config.get("max_results", 3)
if all(isinstance(value, str) for value in config.values()):
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_config = safe_deepcopy(config)
super().__init__(prompt, config, source, schema)

View File

@ -2,9 +2,10 @@
JSONScraperMultiGraph Module
"""
from copy import copy, deepcopy
from copy import deepcopy
from typing import List, Optional
from pydantic import BaseModel
from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
from .json_scraper_graph import JSONScraperGraph
@ -12,6 +13,7 @@ from ..nodes import (
GraphIteratorNode,
MergeAnswersNode
)
from ..utils.copy import safe_deepcopy
class JSONScraperMultiGraph(AbstractGraph):
"""
@ -45,10 +47,7 @@ class JSONScraperMultiGraph(AbstractGraph):
self.max_results = config.get("max_results", 3)
if all(isinstance(value, str) for value in config.values()):
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_config = safe_deepcopy(config)
self.copy_schema = deepcopy(schema)

View File

@ -12,6 +12,7 @@ from ..nodes import (
GraphIteratorNode,
MergeAnswersNode
)
from ..utils.copy import safe_deepcopy
class MDScraperMultiGraph(AbstractGraph):
"""
@ -42,11 +43,7 @@ class MDScraperMultiGraph(AbstractGraph):
"""
def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[BaseModel] = None):
if all(isinstance(value, str) for value in config.values()):
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_config = safe_deepcopy(config)
self.copy_schema = deepcopy(schema)
super().__init__(prompt, config, source, schema)

View File

@ -2,7 +2,7 @@
OmniSearchGraph Module
"""
from copy import copy, deepcopy
from copy import deepcopy
from typing import Optional
from pydantic import BaseModel
@ -15,6 +15,7 @@ from ..nodes import (
GraphIteratorNode,
MergeAnswersNode
)
from ..utils.copy import safe_deepcopy
class OmniSearchGraph(AbstractGraph):
@ -48,10 +49,7 @@ class OmniSearchGraph(AbstractGraph):
self.max_results = config.get("max_results", 3)
if all(isinstance(value, str) for value in config.values()):
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_config = safe_deepcopy(config)
self.copy_schema = deepcopy(schema)

View File

@ -2,7 +2,7 @@
PdfScraperMultiGraph Module
"""
from copy import copy, deepcopy
from copy import deepcopy
from typing import List, Optional
from pydantic import BaseModel
from .base_graph import BaseGraph
@ -12,6 +12,7 @@ from ..nodes import (
GraphIteratorNode,
MergeAnswersNode
)
from ..utils.copy import safe_deepcopy
class PdfScraperMultiGraph(AbstractGraph):
"""
@ -44,10 +45,7 @@ class PdfScraperMultiGraph(AbstractGraph):
def __init__(self, prompt: str, source: List[str],
config: dict, schema: Optional[BaseModel] = None):
if all(isinstance(value, str) for value in config.values()):
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_config = safe_deepcopy(config)
self.copy_schema = deepcopy(schema)

View File

@ -2,7 +2,6 @@
ScriptCreatorMultiGraph Module
"""
from copy import copy, deepcopy
from typing import List, Optional
from pydantic import BaseModel
@ -15,6 +14,7 @@ from ..nodes import (
GraphIteratorNode,
MergeGeneratedScriptsNode
)
from ..utils.copy import safe_deepcopy
class ScriptCreatorMultiGraph(AbstractGraph):
"""
@ -47,10 +47,7 @@ class ScriptCreatorMultiGraph(AbstractGraph):
self.max_results = config.get("max_results", 3)
if all(isinstance(value, str) for value in config.values()):
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_config = safe_deepcopy(config)
super().__init__(prompt, config, source, schema)

View File

@ -2,7 +2,7 @@
SearchGraph Module
"""
from copy import copy, deepcopy
from copy import deepcopy
from typing import Optional, List
from pydantic import BaseModel
@ -15,6 +15,7 @@ from ..nodes import (
GraphIteratorNode,
MergeAnswersNode
)
from ..utils.copy import safe_deepcopy
class SearchGraph(AbstractGraph):
"""
@ -47,10 +48,7 @@ class SearchGraph(AbstractGraph):
def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None):
self.max_results = config.get("max_results", 3)
if all(isinstance(value, str) for value in config.values()):
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_config = safe_deepcopy(config)
self.copy_schema = deepcopy(schema)
self.considered_urls = [] # New attribute to store URLs

View File

@ -2,7 +2,7 @@
SmartScraperMultiGraph Module
"""
from copy import copy, deepcopy
from copy import deepcopy
from typing import List, Optional
from pydantic import BaseModel
@ -14,6 +14,7 @@ from ..nodes import (
GraphIteratorNode,
MergeAnswersNode
)
from ..utils.copy import safe_deepcopy
class SmartScraperMultiGraph(AbstractGraph):
"""
@ -48,10 +49,7 @@ class SmartScraperMultiGraph(AbstractGraph):
self.max_results = config.get("max_results", 3)
if all(isinstance(value, str) for value in config.values()):
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_config = safe_deepcopy(config)
self.copy_schema = deepcopy(schema)

View File

@ -2,7 +2,7 @@
XMLScraperMultiGraph Module
"""
from copy import copy, deepcopy
from copy import deepcopy
from typing import List, Optional
from pydantic import BaseModel
@ -14,6 +14,7 @@ from ..nodes import (
GraphIteratorNode,
MergeAnswersNode
)
from ..utils.copy import safe_deepcopy
class XMLScraperMultiGraph(AbstractGraph):
"""
@ -46,10 +47,7 @@ class XMLScraperMultiGraph(AbstractGraph):
def __init__(self, prompt: str, source: List[str],
config: dict, schema: Optional[BaseModel] = None):
if all(isinstance(value, str) for value in config.values()):
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_config = safe_deepcopy(config)
self.copy_schema = deepcopy(schema)