mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-12 21:01:54 +08:00
103 lines
4.0 KiB
Python
103 lines
4.0 KiB
Python
"""
This module contains functions that prompt a language model to generate corrected code
for different types of errors.

Functions:
- syntax_focused_code_generation: Generates corrected code based on syntax error analysis.
- execution_focused_code_generation: Generates corrected code based on execution error analysis.
- validation_focused_code_generation: Generates corrected code based on validation error
  analysis, considering the JSON schema.
- semantic_focused_code_generation: Generates corrected code based on semantic error analysis,
  comparing generated and reference results.
"""
|
|
|
|
import json
|
|
from langchain.prompts import PromptTemplate
|
|
from langchain_core.output_parsers import StrOutputParser
|
|
from ..prompts import (
|
|
TEMPLATE_SYNTAX_CODE_GENERATION, TEMPLATE_EXECUTION_CODE_GENERATION,
|
|
TEMPLATE_VALIDATION_CODE_GENERATION, TEMPLATE_SEMANTIC_CODE_GENERATION
|
|
)
|
|
|
|
def syntax_focused_code_generation(state: dict, analysis: str, llm_model) -> str:
    """
    Ask the language model to rewrite the generated code so it fixes reported syntax errors.

    Args:
        state (dict): State dictionary; only its 'generated_code' entry is read.
        analysis (str): The syntax-error analysis inserted into the prompt.
        llm_model: The language model placed between the prompt and the string output parser.

    Returns:
        str: The model response parsed to a plain string (expected to be the corrected code).
    """
    syntax_prompt = PromptTemplate(
        template=TEMPLATE_SYNTAX_CODE_GENERATION,
        input_variables=["analysis", "generated_code"],
    )

    # Compose prompt -> model -> plain-string parser.
    pipeline = syntax_prompt | llm_model
    pipeline = pipeline | StrOutputParser()

    prompt_inputs = dict(analysis=analysis, generated_code=state["generated_code"])
    return pipeline.invoke(prompt_inputs)
|
|
|
|
def execution_focused_code_generation(state: dict, analysis: str, llm_model) -> str:
    """
    Produce a revised version of the generated code that addresses execution (runtime) errors.

    Args:
        state (dict): State dictionary; its 'generated_code' entry is the code to revise.
        analysis (str): Description of the execution errors, inserted into the prompt.
        llm_model: The language model invoked by the prompt -> model -> parser chain.

    Returns:
        str: The model response as a plain string (expected to be the corrected code).
    """
    execution_prompt = PromptTemplate(
        template=TEMPLATE_EXECUTION_CODE_GENERATION,
        input_variables=["analysis", "generated_code"],
    )
    runnable = execution_prompt | llm_model | StrOutputParser()

    payload = {}
    payload["analysis"] = analysis
    payload["generated_code"] = state["generated_code"]
    return runnable.invoke(payload)
|
|
|
|
def _schema_to_prompt_text(json_schema) -> str:
    """
    Render a JSON schema as text suitable for inclusion in a prompt.

    A string is returned unchanged, on the assumption that it is already serialized, so it
    is not double-encoded. Any other value is serialized as indented JSON. The model then
    sees valid JSON (double quotes, true/false/null) rather than a Python repr.
    ``default=str`` keeps unexpected non-serializable values from raising.
    """
    if isinstance(json_schema, str):
        return json_schema
    return json.dumps(json_schema, indent=2, ensure_ascii=False, default=str)


def validation_focused_code_generation(state: dict, analysis: str, llm_model) -> str:
    """
    Generates corrected code based on validation error analysis.

    The JSON schema from the state is rendered as JSON text before it is placed into the
    prompt. A schema that is already a string is passed through unchanged.

    Args:
        state (dict): Contains the 'generated_code' and 'json_schema'.
        analysis (str): The analysis of the validation errors.
        llm_model: The language model used for generating the corrected code.

    Returns:
        str: The corrected code, as returned by the model.
    """
    prompt = PromptTemplate(
        template=TEMPLATE_VALIDATION_CODE_GENERATION,
        input_variables=["analysis", "generated_code", "json_schema"],
    )
    chain = prompt | llm_model | StrOutputParser()
    return chain.invoke({
        "analysis": analysis,
        "generated_code": state["generated_code"],
        "json_schema": _schema_to_prompt_text(state["json_schema"]),
    })
|
|
|
|
def _result_to_prompt_json(value) -> str:
    """
    Serialize a value as indented JSON for inclusion in a prompt.

    ``default=str`` converts values the json module cannot encode (e.g. sets or custom
    objects) to their string form. Without it, a TypeError would abort the correction step.
    ``ensure_ascii=False`` keeps non-ASCII text readable instead of emitting Unicode
    escape sequences.
    """
    return json.dumps(value, indent=2, ensure_ascii=False, default=str)


def semantic_focused_code_generation(state: dict, analysis: str, llm_model) -> str:
    """
    Generates corrected code based on semantic error analysis.

    The execution result and the reference answer are both serialized to indented JSON so
    the model can compare them side by side.

    Args:
        state (dict): Contains the 'generated_code', 'execution_result', and 'reference_answer'.
        analysis (str): The analysis of the semantic differences.
        llm_model: The language model used for generating the corrected code.

    Returns:
        str: The corrected code, as returned by the model.
    """
    prompt = PromptTemplate(
        template=TEMPLATE_SEMANTIC_CODE_GENERATION,
        input_variables=["analysis", "generated_code", "generated_result", "reference_result"],
    )
    chain = prompt | llm_model | StrOutputParser()
    return chain.invoke({
        "analysis": analysis,
        "generated_code": state["generated_code"],
        "generated_result": _result_to_prompt_json(state["execution_result"]),
        "reference_result": _result_to_prompt_json(state["reference_answer"]),
    })
|