mirror of
https://github.com/lanqian528/chat2api.git
synced 2026-06-19 21:05:19 +08:00
Merge pull request #7 from ShawnJim/main
fix some error in chat_response.
This commit is contained in:
commit
9aeeff3e61
1
.gitignore
vendored
1
.gitignore
vendored
@ -1 +1,2 @@
|
||||
.env
|
||||
/.idea/
|
||||
|
||||
37
app.py
37
app.py
@ -1,6 +1,5 @@
|
||||
from fastapi import FastAPI, Request, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
|
||||
from chatgpt.ChatService import ChatService
|
||||
from utils.authorization import verify_token
|
||||
from utils.retry import async_retry
|
||||
@ -8,29 +7,47 @@ from utils.retry import async_retry
|
||||
# FastAPI application object; the route handlers in this module are registered on it.
app = FastAPI()
|
||||
|
||||
|
||||
async def to_send_conversation(request_data, access_token):
    """
    Prepare the prerequisites for a conversation.

    :param request_data: request payload
    :param access_token: access token
    :return: the ChatService instance, after get_chat_requirements() has been awaited on it
    """
    service = ChatService(request_data, access_token)
    await service.get_chat_requirements()
    return service
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
async def send_conversation(request: Request, token=Depends(verify_token)):
    """
    Send a conversation.

    :param request: incoming request; its body must be a JSON object
    :param token: value produced by the verify_token dependency (may be empty)
    :return: a StreamingResponse (text/event-stream) when the body sets "stream" to true,
             otherwise a JSONResponse
    :raises HTTPException: 400 when the body is not valid JSON or is not a JSON object
    """
    # Only a token starting with "ey" is forwarded as an access token
    # (presumably a JWT); any other value leaves access_token as None.
    access_token = token if token and token.startswith("ey") else None

    # Validate the request body.
    try:
        request_data = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail={"error": "Invalid JSON body"})
    if not isinstance(request_data, dict):
        # The payload is read with .get() below, so a JSON array or scalar
        # is rejected here instead of failing later with an AttributeError.
        raise HTTPException(status_code=400, detail={"error": "Invalid JSON body"})

    # Prepare the conversation prerequisites exactly once, through the
    # module-level helper wrapped by async_retry.
    chat_service = await async_retry(to_send_conversation, request_data, access_token)
    chat_service.prepare_send_conversation()

    # Send the conversation.
    if request_data.get("stream", False) is True:
        # Streaming response.
        return StreamingResponse(await chat_service.send_conversation_for_stream(), media_type="text/event-stream")

    # Non-streaming response.
    return JSONResponse(await chat_service.send_conversation(), media_type="application/json")
|
||||
|
||||
|
||||
|
||||
@ -73,22 +73,58 @@ async def stream_response(response, model, max_tokens):
|
||||
}
|
||||
completion_tokens += 1
|
||||
yield f"data: {json.dumps(chunk_new_data)}\n\n"
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
Logger.error(f"Error: {chunk}")
|
||||
continue
|
||||
|
||||
|
||||
async def chat_response(resp, model, prompt_tokens, max_tokens):
    """
    Assemble the chat completion response.

    :param resp: response data, iterated in reverse as "data: ..." lines
    :param model: model name
    :param prompt_tokens: number of prompt tokens
    :param max_tokens: maximum number of tokens
    :return: dict describing a "chat.completion" object, including the conversation_id
             taken from the selected frame
    """
    last_resp = None
    # Walk the frames from the end and keep the last one that carries a
    # "message" payload and parses as JSON.
    for line in reversed(resp):
        if line == "data: [DONE]" or not line.startswith("data: "):
            continue
        if '"message":' not in line:
            continue
        try:
            last_resp = json.loads(line[6:])  # drop the "data: " prefix
            break
        except Exception:
            Logger.error(f"Error: {line}")
            continue

    # NOTE(review): when no frame qualifies, last_resp is still None here and the
    # calls below fail; confirm whether callers expect a dedicated error instead.
    usage, system_fingerprint, finish_reason, message = await init_param(last_resp, max_tokens, model, prompt_tokens)
    return {
        "id": f"chatcmpl-{''.join(random.choice(string.ascii_letters + string.digits) for _ in range(29))}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": message,
                "logprobs": None,
                "finish_reason": finish_reason
            }
        ],
        "usage": usage,
        "system_fingerprint": system_fingerprint,
        "conversation_id": last_resp["conversation_id"]
    }
|
||||
|
||||
|
||||
async def init_param(last_resp, max_tokens, model, prompt_tokens):
|
||||
"""
|
||||
组装参数
|
||||
:param last_resp: 最终响应
|
||||
:param max_tokens: 最大token数量
|
||||
:param model: 模型
|
||||
:param prompt_tokens: prompt token数
|
||||
:return: 对应参数
|
||||
"""
|
||||
if last_resp.get("type") == "moderation":
|
||||
message_content = moderation_message
|
||||
completion_tokens = 53
|
||||
@ -96,16 +132,10 @@ async def chat_response(resp, model, prompt_tokens, max_tokens):
|
||||
else:
|
||||
message_content = last_resp["message"]["content"]["parts"][0]
|
||||
message_content, completion_tokens, finish_reason = split_tokens_from_content(message_content, max_tokens, model)
|
||||
chat_id = f"chatcmpl-{''.join(random.choice(string.ascii_letters + string.digits) for _ in range(29))}"
|
||||
chat_object = "chat.completion"
|
||||
created_time = int(time.time())
|
||||
index = 0
|
||||
message = {
|
||||
"role": "assistant",
|
||||
"content": message_content,
|
||||
|
||||
}
|
||||
logprobs = None
|
||||
usage = {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
@ -113,23 +143,7 @@ async def chat_response(resp, model, prompt_tokens, max_tokens):
|
||||
}
|
||||
system_fingerprint_list = model_system_fingerprint.get(model, None)
|
||||
system_fingerprint = random.choice(system_fingerprint_list) if system_fingerprint_list else None
|
||||
chat_response_json = {
|
||||
"id": chat_id,
|
||||
"object": chat_object,
|
||||
"created": created_time,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": index,
|
||||
"message": message,
|
||||
"logprobs": logprobs,
|
||||
"finish_reason": finish_reason
|
||||
}
|
||||
],
|
||||
"usage": usage,
|
||||
"system_fingerprint": system_fingerprint
|
||||
}
|
||||
return chat_response_json
|
||||
return usage, system_fingerprint, finish_reason, message
|
||||
|
||||
|
||||
def api_messages_to_chat(api_messages):
|
||||
@ -157,6 +171,7 @@ class ChatService:
|
||||
self.chat_token = None
|
||||
|
||||
self.data = data
|
||||
self.conversation_id = self.data.get("conversation_id", None)
|
||||
self.model = self.data.get("model", "gpt-3.5-turbo-0125")
|
||||
self.api_messages = self.data.get("messages", [])
|
||||
self.prompt_tokens = num_tokens_from_messages(self.api_messages, self.model)
|
||||
@ -247,6 +262,8 @@ class ChatService:
|
||||
"force_nulligen": False,
|
||||
"force_rate_limit": False,
|
||||
}
|
||||
if self.conversation_id:
|
||||
self.chat_request["conversation_id"] = self.conversation_id
|
||||
return self.chat_request
|
||||
|
||||
async def send_conversation_for_stream(self):
|
||||
|
||||
@ -1,15 +1,17 @@
|
||||
from fastapi import Depends
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
from utils.config import authorization_list
|
||||
|
||||
# auto_error=False: a request without an Authorization header yields None instead of
# an automatic 401, so verify_token decides how a missing token is handled.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
|
||||
|
||||
|
||||
def verify_token(token: str = Depends(oauth2_scheme)):
    """
    Validate the bearer token supplied with a request.

    :param token: value provided by oauth2_scheme; empty when no token was sent
    :return: the token when it is accepted, or None when no token was sent and
             authorization_list is empty
    :raises HTTPException: 401 when the token is rejected, or when it is missing
                           while authorization_list is configured
    """
    if not token:
        # A missing token must not bypass a configured authorization_list.
        if authorization_list:
            raise HTTPException(status_code=401, detail="Not authenticated")
        return None
    if not authorization_list or token in authorization_list:
        return token
    if token.startswith("eyJhbGciOi"):
        # Tokens with this prefix are accepted even when not listed
        # (presumably upstream JWT access tokens).
        return token
    # Reject instead of returning a falsy value that callers could mistake
    # for "no token supplied".
    raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user