fix: deepcopy fail for coping model_instance config

This commit is contained in:
smith peng 2024-08-31 12:42:08 +08:00
parent 4f4d091b82
commit cd07418474
9 changed files with 27 additions and 45 deletions

View File

@ -2,9 +2,10 @@
CSVScraperMultiGraph Module CSVScraperMultiGraph Module
""" """
from copy import copy, deepcopy
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel
from .base_graph import BaseGraph from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph from .abstract_graph import AbstractGraph
from .csv_scraper_graph import CSVScraperGraph from .csv_scraper_graph import CSVScraperGraph
@ -12,6 +13,7 @@ from ..nodes import (
GraphIteratorNode, GraphIteratorNode,
MergeAnswersNode MergeAnswersNode
) )
from ..utils.copy import safe_deepcopy
class CSVScraperMultiGraph(AbstractGraph): class CSVScraperMultiGraph(AbstractGraph):
""" """
@ -46,10 +48,7 @@ class CSVScraperMultiGraph(AbstractGraph):
self.max_results = config.get("max_results", 3) self.max_results = config.get("max_results", 3)
if all(isinstance(value, str) for value in config.values()): self.copy_config = safe_deepcopy(config)
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
super().__init__(prompt, config, source, schema) super().__init__(prompt, config, source, schema)

View File

@ -2,9 +2,10 @@
JSONScraperMultiGraph Module JSONScraperMultiGraph Module
""" """
from copy import copy, deepcopy from copy import deepcopy
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel
from .base_graph import BaseGraph from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph from .abstract_graph import AbstractGraph
from .json_scraper_graph import JSONScraperGraph from .json_scraper_graph import JSONScraperGraph
@ -12,6 +13,7 @@ from ..nodes import (
GraphIteratorNode, GraphIteratorNode,
MergeAnswersNode MergeAnswersNode
) )
from ..utils.copy import safe_deepcopy
class JSONScraperMultiGraph(AbstractGraph): class JSONScraperMultiGraph(AbstractGraph):
""" """
@ -45,10 +47,7 @@ class JSONScraperMultiGraph(AbstractGraph):
self.max_results = config.get("max_results", 3) self.max_results = config.get("max_results", 3)
if all(isinstance(value, str) for value in config.values()): self.copy_config = safe_deepcopy(config)
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_schema = deepcopy(schema) self.copy_schema = deepcopy(schema)

View File

@ -12,6 +12,7 @@ from ..nodes import (
GraphIteratorNode, GraphIteratorNode,
MergeAnswersNode MergeAnswersNode
) )
from ..utils.copy import safe_deepcopy
class MDScraperMultiGraph(AbstractGraph): class MDScraperMultiGraph(AbstractGraph):
""" """
@ -42,11 +43,7 @@ class MDScraperMultiGraph(AbstractGraph):
""" """
def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[BaseModel] = None): def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[BaseModel] = None):
if all(isinstance(value, str) for value in config.values()): self.copy_config = safe_deepcopy(config)
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_schema = deepcopy(schema) self.copy_schema = deepcopy(schema)
super().__init__(prompt, config, source, schema) super().__init__(prompt, config, source, schema)

View File

@ -2,7 +2,7 @@
OmniSearchGraph Module OmniSearchGraph Module
""" """
from copy import copy, deepcopy from copy import deepcopy
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
@ -15,6 +15,7 @@ from ..nodes import (
GraphIteratorNode, GraphIteratorNode,
MergeAnswersNode MergeAnswersNode
) )
from ..utils.copy import safe_deepcopy
class OmniSearchGraph(AbstractGraph): class OmniSearchGraph(AbstractGraph):
@ -48,10 +49,7 @@ class OmniSearchGraph(AbstractGraph):
self.max_results = config.get("max_results", 3) self.max_results = config.get("max_results", 3)
if all(isinstance(value, str) for value in config.values()): self.copy_config = safe_deepcopy(config)
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_schema = deepcopy(schema) self.copy_schema = deepcopy(schema)

View File

@ -2,7 +2,7 @@
PdfScraperMultiGraph Module PdfScraperMultiGraph Module
""" """
from copy import copy, deepcopy from copy import deepcopy
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel
from .base_graph import BaseGraph from .base_graph import BaseGraph
@ -12,6 +12,7 @@ from ..nodes import (
GraphIteratorNode, GraphIteratorNode,
MergeAnswersNode MergeAnswersNode
) )
from ..utils.copy import safe_deepcopy
class PdfScraperMultiGraph(AbstractGraph): class PdfScraperMultiGraph(AbstractGraph):
""" """
@ -44,10 +45,7 @@ class PdfScraperMultiGraph(AbstractGraph):
def __init__(self, prompt: str, source: List[str], def __init__(self, prompt: str, source: List[str],
config: dict, schema: Optional[BaseModel] = None): config: dict, schema: Optional[BaseModel] = None):
if all(isinstance(value, str) for value in config.values()): self.copy_config = safe_deepcopy(config)
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_schema = deepcopy(schema) self.copy_schema = deepcopy(schema)

View File

@ -2,7 +2,6 @@
ScriptCreatorMultiGraph Module ScriptCreatorMultiGraph Module
""" """
from copy import copy, deepcopy
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel
@ -15,6 +14,7 @@ from ..nodes import (
GraphIteratorNode, GraphIteratorNode,
MergeGeneratedScriptsNode MergeGeneratedScriptsNode
) )
from ..utils.copy import safe_deepcopy
class ScriptCreatorMultiGraph(AbstractGraph): class ScriptCreatorMultiGraph(AbstractGraph):
""" """
@ -47,10 +47,7 @@ class ScriptCreatorMultiGraph(AbstractGraph):
self.max_results = config.get("max_results", 3) self.max_results = config.get("max_results", 3)
if all(isinstance(value, str) for value in config.values()): self.copy_config = safe_deepcopy(config)
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
super().__init__(prompt, config, source, schema) super().__init__(prompt, config, source, schema)

View File

@ -2,7 +2,7 @@
SearchGraph Module SearchGraph Module
""" """
from copy import copy, deepcopy from copy import deepcopy
from typing import Optional, List from typing import Optional, List
from pydantic import BaseModel from pydantic import BaseModel
@ -15,6 +15,7 @@ from ..nodes import (
GraphIteratorNode, GraphIteratorNode,
MergeAnswersNode MergeAnswersNode
) )
from ..utils.copy import safe_deepcopy
class SearchGraph(AbstractGraph): class SearchGraph(AbstractGraph):
""" """
@ -47,10 +48,7 @@ class SearchGraph(AbstractGraph):
def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None): def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None):
self.max_results = config.get("max_results", 3) self.max_results = config.get("max_results", 3)
if all(isinstance(value, str) for value in config.values()): self.copy_config = safe_deepcopy(config)
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_schema = deepcopy(schema) self.copy_schema = deepcopy(schema)
self.considered_urls = [] # New attribute to store URLs self.considered_urls = [] # New attribute to store URLs

View File

@ -2,7 +2,7 @@
SmartScraperMultiGraph Module SmartScraperMultiGraph Module
""" """
from copy import copy, deepcopy from copy import deepcopy
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel
@ -14,6 +14,7 @@ from ..nodes import (
GraphIteratorNode, GraphIteratorNode,
MergeAnswersNode MergeAnswersNode
) )
from ..utils.copy import safe_deepcopy
class SmartScraperMultiGraph(AbstractGraph): class SmartScraperMultiGraph(AbstractGraph):
""" """
@ -48,10 +49,7 @@ class SmartScraperMultiGraph(AbstractGraph):
self.max_results = config.get("max_results", 3) self.max_results = config.get("max_results", 3)
if all(isinstance(value, str) for value in config.values()): self.copy_config = safe_deepcopy(config)
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_schema = deepcopy(schema) self.copy_schema = deepcopy(schema)

View File

@ -2,7 +2,7 @@
XMLScraperMultiGraph Module XMLScraperMultiGraph Module
""" """
from copy import copy, deepcopy from copy import deepcopy
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel
@ -14,6 +14,7 @@ from ..nodes import (
GraphIteratorNode, GraphIteratorNode,
MergeAnswersNode MergeAnswersNode
) )
from ..utils.copy import safe_deepcopy
class XMLScraperMultiGraph(AbstractGraph): class XMLScraperMultiGraph(AbstractGraph):
""" """
@ -46,10 +47,7 @@ class XMLScraperMultiGraph(AbstractGraph):
def __init__(self, prompt: str, source: List[str], def __init__(self, prompt: str, source: List[str],
config: dict, schema: Optional[BaseModel] = None): config: dict, schema: Optional[BaseModel] = None):
if all(isinstance(value, str) for value in config.values()): self.copy_config = safe_deepcopy(config)
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)
self.copy_schema = deepcopy(schema) self.copy_schema = deepcopy(schema)