mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-23 21:00:30 +08:00
101 lines
2.6 KiB
Python
101 lines
2.6 KiB
Python
"""
|
|
Functions to retrieve the correct output parser and format instructions for the LLM model.
|
|
"""
|
|
|
|
from typing import Any, Callable, Dict, Type, Union
|
|
|
|
from langchain_core.output_parsers import JsonOutputParser
|
|
from pydantic import BaseModel as BaseModelV2
|
|
from pydantic.v1 import BaseModel as BaseModelV1
|
|
|
|
|
|
def get_structured_output_parser(
|
|
schema: Union[Dict[str, Any], Type[BaseModelV1 | BaseModelV2], Type]
|
|
) -> Callable:
|
|
"""
|
|
Get the correct output parser for the LLM model.
|
|
|
|
Returns:
|
|
Callable: The output parser function.
|
|
"""
|
|
if issubclass(schema, BaseModelV1):
|
|
return _base_model_v1_output_parser
|
|
|
|
if issubclass(schema, BaseModelV2):
|
|
return _base_model_v2_output_parser
|
|
|
|
return _dict_output_parser
|
|
|
|
|
|
def get_pydantic_output_parser(
|
|
schema: Union[Dict[str, Any], Type[BaseModelV1 | BaseModelV2], Type]
|
|
) -> JsonOutputParser:
|
|
"""
|
|
Get the correct output parser for the LLM model.
|
|
|
|
Returns:
|
|
JsonOutputParser: The output parser object.
|
|
"""
|
|
if issubclass(schema, BaseModelV1):
|
|
raise ValueError(
|
|
"""pydantic.v1 and langchain_core.pydantic_v1
|
|
are not supported with this LLM model. Please use pydantic v2 instead."""
|
|
)
|
|
|
|
if issubclass(schema, BaseModelV2):
|
|
return JsonOutputParser(pydantic_object=schema)
|
|
|
|
raise ValueError(
|
|
"""The schema is not a pydantic subclass.
|
|
With this LLM model you must use a pydantic schemas."""
|
|
)
|
|
|
|
|
|
def _base_model_v1_output_parser(x: BaseModelV1) -> dict:
|
|
"""
|
|
Parse the output of an LLM when the schema is BaseModelv1.
|
|
|
|
Args:
|
|
x (BaseModelV1): The output from the LLM model.
|
|
|
|
Returns:
|
|
dict: The parsed output.
|
|
"""
|
|
work_dict = x.dict()
|
|
|
|
def recursive_dict_parser(work_dict: dict) -> dict:
|
|
dict_keys = work_dict.keys()
|
|
for key in dict_keys:
|
|
if isinstance(work_dict[key], BaseModelV1):
|
|
work_dict[key] = work_dict[key].dict()
|
|
recursive_dict_parser(work_dict[key])
|
|
return work_dict
|
|
|
|
return recursive_dict_parser(work_dict)
|
|
|
|
|
|
def _base_model_v2_output_parser(x: BaseModelV2) -> dict:
|
|
"""
|
|
Parse the output of an LLM when the schema is BaseModelv2.
|
|
|
|
Args:
|
|
x (BaseModelV2): The output from the LLM model.
|
|
|
|
Returns:
|
|
dict: The parsed output.
|
|
"""
|
|
return x.model_dump()
|
|
|
|
|
|
def _dict_output_parser(x: dict) -> dict:
|
|
"""
|
|
Parse the output of an LLM when the schema is TypedDict or JsonSchema.
|
|
|
|
Args:
|
|
x (dict): The output from the LLM model.
|
|
|
|
Returns:
|
|
dict: The parsed output.
|
|
"""
|
|
return x
|