Merge pull request #7 from ShawnJim/main

Fix some errors in chat_response.
This commit is contained in:
LanQian 2024-04-12 17:06:03 +08:00 committed by GitHub
commit 9aeeff3e61
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 76 additions and 39 deletions

1
.gitignore vendored
View File

@ -1 +1,2 @@
.env
/.idea/

37
app.py
View File

@ -1,6 +1,5 @@
from fastapi import FastAPI, Request, Depends, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from chatgpt.ChatService import ChatService
from utils.authorization import verify_token
from utils.retry import async_retry
@ -8,29 +7,47 @@ from utils.retry import async_retry
app = FastAPI()
async def to_send_conversation(request_data, access_token):
    """
    Build a chat service and prepare it for a conversation.

    :param request_data: request data handed to ``ChatService``
    :param access_token: access token handed to ``ChatService``
    :return: the ``ChatService`` instance, after its
        ``get_chat_requirements()`` coroutine has been awaited
    """
    service = ChatService(request_data, access_token)
    await service.get_chat_requirements()
    return service
@app.post("/v1/chat/completions")
async def send_conversation(request: Request, token=Depends(verify_token)):
"""
Send a conversation.

Handles POST /v1/chat/completions: parses the JSON body, prepares a chat
service through ``async_retry`` and returns either a streaming or a JSON
response depending on the ``stream`` field of the body.

:param request: incoming request
:param token: access token resolved by the ``verify_token`` dependency
:return: conversation result (``StreamingResponse`` when ``stream`` is True,
    otherwise ``JSONResponse``)
"""
# NOTE(review): this span looks like a rendered diff with indentation
# stripped - two token checks and two async_retry calls are shown side by
# side; confirm which lines are current before editing the logic.
access_token = None
if not token:
raise HTTPException(status_code=401, detail={"error": "Not authenticated"})
elif token.startswith("eyJhb"):
if token and token.startswith("ey"):
access_token = token
# Validate the format of the request body
try:
request_data = await request.json()
except Exception:
raise HTTPException(status_code=400, detail={"error": "Invalid JSON body"})
async def to_send_conversation(request_data):
chat_service = ChatService(request_data, access_token)
await chat_service.get_chat_requirements()
return chat_service
chat_service = await async_retry(to_send_conversation, request_data)
# Prepare the prerequisites for the conversation
chat_service = await async_retry(to_send_conversation, request_data, access_token)
chat_service.prepare_send_conversation()
# Send the conversation
stream = request_data.get("stream", False)
if stream is True:
# Streaming response
return StreamingResponse(await chat_service.send_conversation_for_stream(), media_type="text/event-stream")
else:
# Non-streaming response
return JSONResponse(await chat_service.send_conversation(), media_type="application/json")

View File

@ -73,22 +73,58 @@ async def stream_response(response, model, max_tokens):
}
completion_tokens += 1
yield f"data: {json.dumps(chunk_new_data)}\n\n"
except Exception as e:
except Exception:
Logger.error(f"Error: {chunk}")
continue
async def chat_response(resp, model, prompt_tokens, max_tokens):
"""
Assemble the chat response.

Walks ``resp`` from the end, parses the first matching ``data: `` line as
JSON, and builds a ``chat.completion`` dict from it via ``init_param``.

:param resp: response data (iterated in reverse; each item is treated as a string)
:param model: model
:param prompt_tokens: number of prompt tokens
:param max_tokens: maximum number of tokens
:return: dict describing the completion, including ``conversation_id``
"""
# NOTE(review): this span looks like a rendered diff with indentation
# stripped - the `if` and `except` lines each appear twice; confirm which
# variant is current.
last_resp = None
# Scan from the end so the latest parseable data line wins
for i in reversed(resp):
if i != "data: [DONE]" and i.startswith("data: "):
if i != "data: [DONE]" and i.startswith("data: ") and '"message":' in i:
try:
# i[6:] drops the leading "data: " prefix before parsing
last_resp = json.loads(i[6:])
break
except Exception as e:
except Exception:
Logger.error(f"Error: {i}")
continue
# NOTE(review): last_resp is still None here when no line matched or parsed;
# init_param and the "conversation_id" lookup below both use it - confirm
# callers guarantee at least one usable line.
usage, system_fingerprint, finish_reason, message = await init_param(last_resp, max_tokens, model, prompt_tokens)
return {
"id": f"chatcmpl-{''.join(random.choice(string.ascii_letters + string.digits) for _ in range(29))}",
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"message": message,
"logprobs": None,
"finish_reason": finish_reason
}
],
"usage": usage,
"system_fingerprint": system_fingerprint,
"conversation_id": last_resp["conversation_id"]
}
async def init_param(last_resp, max_tokens, model, prompt_tokens):
"""
Assemble the parameters.

:param last_resp: final response
:param max_tokens: maximum number of tokens
:param model: model
:param prompt_tokens: number of prompt tokens
:return: the corresponding parameters, as the tuple
    ``(usage, system_fingerprint, finish_reason, message)``
"""
# NOTE(review): this span looks like a rendered diff with indentation
# stripped - lines are elided at the two "@ ... @" hunk markers and older
# lines (chat_id ... return chat_response_json) are shown next to their
# replacement; confirm against the real file before changing logic.
if last_resp.get("type") == "moderation":
message_content = moderation_message
completion_tokens = 53
@ -96,16 +132,10 @@ async def chat_response(resp, model, prompt_tokens, max_tokens):
else:
message_content = last_resp["message"]["content"]["parts"][0]
message_content, completion_tokens, finish_reason = split_tokens_from_content(message_content, max_tokens, model)
chat_id = f"chatcmpl-{''.join(random.choice(string.ascii_letters + string.digits) for _ in range(29))}"
chat_object = "chat.completion"
created_time = int(time.time())
index = 0
message = {
"role": "assistant",
"content": message_content,
}
logprobs = None
usage = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
@ -113,23 +143,7 @@ async def chat_response(resp, model, prompt_tokens, max_tokens):
}
# Pick a random fingerprint from the model's list, or None when the model has no entry
system_fingerprint_list = model_system_fingerprint.get(model, None)
system_fingerprint = random.choice(system_fingerprint_list) if system_fingerprint_list else None
chat_response_json = {
"id": chat_id,
"object": chat_object,
"created": created_time,
"model": model,
"choices": [
{
"index": index,
"message": message,
"logprobs": logprobs,
"finish_reason": finish_reason
}
],
"usage": usage,
"system_fingerprint": system_fingerprint
}
return chat_response_json
return usage, system_fingerprint, finish_reason, message
def api_messages_to_chat(api_messages):
@ -157,6 +171,7 @@ class ChatService:
self.chat_token = None
self.data = data
self.conversation_id = self.data.get("conversation_id", None)
self.model = self.data.get("model", "gpt-3.5-turbo-0125")
self.api_messages = self.data.get("messages", [])
self.prompt_tokens = num_tokens_from_messages(self.api_messages, self.model)
@ -247,6 +262,8 @@ class ChatService:
"force_nulligen": False,
"force_rate_limit": False,
}
if self.conversation_id:
self.chat_request["conversation_id"] = self.conversation_id
return self.chat_request
async def send_conversation_for_stream(self):

View File

@ -1,15 +1,17 @@
from fastapi import Depends
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from utils.config import authorization_list
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
def verify_token(token: str = Depends(oauth2_scheme)):
"""
Check the bearer token supplied by the ``oauth2_scheme`` dependency.

A missing or empty token yields ``None``. The token itself is returned when
``authorization_list`` is empty, when the token is listed in it, or when
the token starts with ``eyJhbGciOi``.

:param token: token extracted by ``oauth2_scheme``
:return: the accepted token, or ``None`` when no token was supplied
"""
if not token:
return None
if not authorization_list or token in authorization_list:
return token
elif token.startswith("eyJhbGciOi"):
return token
else:
# NOTE(review): this span looks like a rendered diff with indentation
# stripped - both `return False` and the 401 `raise` are shown for the
# rejection branch; confirm which one is current.
return False
raise HTTPException(status_code=401, detail="Not authenticated")