Scrapegraph-ai/scrapegraphai/utils/code_error_correction.py
2024-10-08 08:54:18 +02:00

103 lines
4.0 KiB
Python

"""
This module contains the functions for code generation to correct different types of errors.

Functions:
- syntax_focused_code_generation: Generates corrected code based on syntax error analysis.
- execution_focused_code_generation: Generates corrected code based on execution error analysis.
- validation_focused_code_generation: Generates corrected code based on
  validation error analysis, considering the JSON schema.
- semantic_focused_code_generation: Generates corrected code based on semantic error analysis,
  comparing generated and reference results.
"""
import json
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from ..prompts import (
TEMPLATE_SYNTAX_CODE_GENERATION, TEMPLATE_EXECUTION_CODE_GENERATION,
TEMPLATE_VALIDATION_CODE_GENERATION, TEMPLATE_SEMANTIC_CODE_GENERATION
)
def syntax_focused_code_generation(state: dict, analysis: str, llm_model) -> str:
    """
    Produce a corrected version of the generated code from a syntax error analysis.

    A prompt built from TEMPLATE_SYNTAX_CODE_GENERATION is piped through
    the given model and a StrOutputParser, then invoked once.

    Args:
        state (dict): Must provide the 'generated_code' entry.
        analysis (str): The analysis of the syntax errors.
        llm_model: The language model placed in the middle of the chain.

    Returns:
        str: The output of the chain, i.e. the corrected code.

    Raises:
        KeyError: If 'generated_code' is missing from a dict state.
    """
    correction_prompt = PromptTemplate(
        template=TEMPLATE_SYNTAX_CODE_GENERATION,
        input_variables=["analysis", "generated_code"],
    )
    pipeline = correction_prompt | llm_model | StrOutputParser()

    # The state is read only after the chain exists, as in the original flow.
    prompt_values = {
        "analysis": analysis,
        "generated_code": state["generated_code"],
    }
    return pipeline.invoke(prompt_values)
def execution_focused_code_generation(state: dict, analysis: str, llm_model) -> str:
    """
    Produce a corrected version of the generated code from an execution error analysis.

    A prompt built from TEMPLATE_EXECUTION_CODE_GENERATION is piped through
    the given model and a StrOutputParser, then invoked once.

    Args:
        state (dict): Must provide the 'generated_code' entry.
        analysis (str): The analysis of the execution errors.
        llm_model: The language model placed in the middle of the chain.

    Returns:
        str: The output of the chain, i.e. the corrected code.

    Raises:
        KeyError: If 'generated_code' is missing from a dict state.
    """
    correction_prompt = PromptTemplate(
        template=TEMPLATE_EXECUTION_CODE_GENERATION,
        input_variables=["analysis", "generated_code"],
    )
    pipeline = correction_prompt | llm_model | StrOutputParser()

    # The state is read only after the chain exists, as in the original flow.
    prompt_values = {
        "analysis": analysis,
        "generated_code": state["generated_code"],
    }
    return pipeline.invoke(prompt_values)
def validation_focused_code_generation(state: dict, analysis: str, llm_model) -> str:
    """
    Produce a corrected version of the generated code from a validation error analysis.

    A prompt built from TEMPLATE_VALIDATION_CODE_GENERATION is piped through
    the given model and a StrOutputParser, then invoked once. The schema
    stored in the state is forwarded to the prompt unchanged.

    Args:
        state (dict): Must provide the 'generated_code' and 'json_schema' entries.
        analysis (str): The analysis of the validation errors.
        llm_model: The language model placed in the middle of the chain.

    Returns:
        str: The output of the chain, i.e. the corrected code.

    Raises:
        KeyError: If 'generated_code' or 'json_schema' is missing from a dict state.
    """
    correction_prompt = PromptTemplate(
        template=TEMPLATE_VALIDATION_CODE_GENERATION,
        input_variables=["analysis", "generated_code", "json_schema"],
    )
    pipeline = correction_prompt | llm_model | StrOutputParser()

    # Read the state entries in the same order as the prompt variables.
    current_code = state["generated_code"]
    schema = state["json_schema"]
    prompt_values = {
        "analysis": analysis,
        "generated_code": current_code,
        "json_schema": schema,
    }
    return pipeline.invoke(prompt_values)
def semantic_focused_code_generation(state: dict, analysis: str, llm_model) -> str:
    """
    Produce a corrected version of the generated code from a semantic error analysis.

    A prompt built from TEMPLATE_SEMANTIC_CODE_GENERATION is piped through
    the given model and a StrOutputParser, then invoked once. The execution
    result and the reference answer are serialised to indented JSON text
    before being handed to the prompt.

    Args:
        state (dict): Must provide the 'generated_code', 'execution_result',
            and 'reference_answer' entries.
        analysis (str): The analysis of the semantic differences.
        llm_model: The language model placed in the middle of the chain.

    Returns:
        str: The output of the chain, i.e. the corrected code.

    Raises:
        KeyError: If one of the required entries is missing from a dict state.
        TypeError: If 'execution_result' or 'reference_answer' holds a value
            that json.dumps cannot serialise.
    """
    correction_prompt = PromptTemplate(
        template=TEMPLATE_SEMANTIC_CODE_GENERATION,
        input_variables=[
            "analysis",
            "generated_code",
            "generated_result",
            "reference_result",
        ],
    )
    pipeline = correction_prompt | llm_model | StrOutputParser()

    # Keep the lookup/serialisation order: code, then result, then reference.
    current_code = state["generated_code"]
    produced_json = json.dumps(state["execution_result"], indent=2)
    expected_json = json.dumps(state["reference_answer"], indent=2)

    prompt_values = {
        "analysis": analysis,
        "generated_code": current_code,
        "generated_result": produced_json,
        "reference_result": expected_json,
    }
    return pipeline.invoke(prompt_values)